aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 57b3118245..6a297da52d 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -678,6 +678,50 @@ TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
}
}
+TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ auto greaterequal = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto noop =
+ ops::NoOp(s.WithOpName("NoOp").WithControlDependencies(greaterequal));
+ Output add = ops::Add(
+ s.WithOpName("z").WithControlDependencies({noop.operation}), x, y);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ item.fetch.push_back("z");
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int count = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "y") {
+ count++;
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "GreaterEqual") {
+ count++;
+ } else if (node.name() == "NoOp") {
+ count++;
+ } else if (node.name() == "z") {
+ count++;
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
+ }
+ }
+ EXPECT_EQ(3, count);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow