From 6cbe92315567c5fca3f96fa049c01a781bbf1338 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 27 Mar 2017 14:52:14 -0800 Subject: Added a new regression test to ensure control dependencies are properly processed Change: 151383375 --- .../core/grappler/optimizers/model_pruner_test.cc | 42 ++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 53f868a6dd..6b90dba864 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -157,6 +157,48 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { EXPECT_EQ("^c", new_e.input(1)); } +TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) { + // Build a simple graph with a few trivially prunable ops. + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + Output b = ops::AddN(s.WithOpName("b"), {a}); + Output c = ops::AddN(s.WithOpName("c"), {a}); + Output d = ops::Identity(s.WithOpName("d"), c); + Output e = ops::Identity(s.WithOpName("e"), d); + Output f = ops::AddN(s.WithOpName("f"), {e}); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Add a control dependency between b and d and another one between c and e. + // They should be properly forwarded. + EXPECT_EQ("d", item.graph.node(3).name()); + EXPECT_EQ("e", item.graph.node(4).name()); + *item.graph.mutable_node(3)->add_input() = "^b"; + *item.graph.mutable_node(4)->add_input() = "^c"; + + ModelPruner pruner; + GraphDef output; + Status status = pruner.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(4, output.node_size()); + const NodeDef& new_a = output.node(0); + EXPECT_EQ(NodeName(a.name()), new_a.name()); + const NodeDef& new_b = output.node(1); + EXPECT_EQ(NodeName(b.name()), new_b.name()); + const NodeDef& new_c = output.node(2); + EXPECT_EQ(NodeName(c.name()), new_c.name()); + const NodeDef& new_f = output.node(3); + EXPECT_EQ(NodeName(f.name()), new_f.name()); + + EXPECT_EQ(3, new_f.input_size()); + EXPECT_EQ(NodeName(c.name()), new_f.input(0)); + EXPECT_EQ("^b", new_f.input(1)); + EXPECT_EQ("^c", new_f.input(2)); +} + } // namespace } // namespace grappler } // namespace tensorflow -- cgit v1.2.3