aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-03-27 14:52:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 16:11:11 -0700
commit6cbe92315567c5fca3f96fa049c01a781bbf1338 (patch)
treefa21407debbf28795d25fd675e7f4aaaef2fbaa5
parent4f12b3b00fd3fb707f962477d45c59dad209e561 (diff)
Added a new regression test to ensure control dependencies are properly
processed Change: 151383375
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner_test.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
index 53f868a6dd..6b90dba864 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
@@ -157,6 +157,48 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
EXPECT_EQ("^c", new_e.input(1));
}
+TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) {
+ // Build a simple graph with a few trivially prunable ops.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output c = ops::AddN(s.WithOpName("c"), {a});
+ Output d = ops::Identity(s.WithOpName("d"), c);
+ Output e = ops::Identity(s.WithOpName("e"), d);
+ Output f = ops::AddN(s.WithOpName("f"), {e});
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // Add a control dependency between b and d and another one between c and e.
+ // They should be properly forwarded.
+ EXPECT_EQ("d", item.graph.node(3).name());
+ EXPECT_EQ("e", item.graph.node(4).name());
+ *item.graph.mutable_node(3)->add_input() = "^b";
+ *item.graph.mutable_node(4)->add_input() = "^c";
+
+ ModelPruner pruner;
+ GraphDef output;
+ Status status = pruner.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(4, output.node_size());
+ const NodeDef& new_a = output.node(0);
+ EXPECT_EQ(NodeName(a.name()), new_a.name());
+ const NodeDef& new_b = output.node(1);
+ EXPECT_EQ(NodeName(b.name()), new_b.name());
+ const NodeDef& new_c = output.node(2);
+ EXPECT_EQ(NodeName(c.name()), new_c.name());
+ const NodeDef& new_f = output.node(3);
+ EXPECT_EQ(NodeName(f.name()), new_f.name());
+
+ EXPECT_EQ(3, new_f.input_size());
+ EXPECT_EQ(NodeName(c.name()), new_f.input(0));
+ EXPECT_EQ("^b", new_f.input(1));
+ EXPECT_EQ("^c", new_f.input(2));
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow