diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 57b3118245..6a297da52d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -678,6 +678,50 @@ TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) { } } +TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape({})); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape({})); + auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto noop = + ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal)); + Output add = ops::Add( + s.WithOpName("z").WithControlDependencies({noop.operation}), x, y); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DependencyOptimizer optimizer; + GraphDef output; + item.fetch.push_back("z"); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "x") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "y") { + count++; + EXPECT_EQ("Placeholder", node.op()); + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "GreaterEqual") { + count++; + } else if (node.name() == "NoOp") { + count++; + } else if (node.name() == "z") { + count++; + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("y", node.input(1)); + } + } + EXPECT_EQ(3, count); +} + } // namespace } // namespace grappler } // namespace tensorflow |