aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2018-06-19 09:42:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 09:44:54 -0700
commita14de341d069387ff8c8a98ff73bf1e5782a5cae (patch)
tree95a0903eed8f1c4a7788bed776eeb36c596c5bb6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent316fee40d4978db2f6abbb5ff35cf8d979bee93e (diff)
Automated g4 rollback of changelist 201069367
PiperOrigin-RevId: 201190626
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc26
1 files changed, 19 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d0e6b04679..e1d55cdf5f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2976,8 +2976,12 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
- Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
- Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
+ Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
+ Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
+ Output sn1 =
+ ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a);
+ Output sn2 =
+ ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1);
Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
Output id1 = ops::Identity(s.WithOpName("id1"), a);
Output id2 = ops::Identity(s.WithOpName("id2"), id1);
@@ -2993,24 +2997,32 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
EnableOnlyRemoveIdempotent(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(7, output.node_size());
+ EXPECT_EQ(11, output.node_size());
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("sn1", node.input(0));
+ EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0));
+ found++;
+ } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") {
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("a", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
+ EXPECT_EQ("^ctrl2", node.input(2));
found++;
} else if (node.name() == "out2") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("id1", node.input(0));
+ EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0));
found++;
- } else if (node.name() == "sn1") {
+ } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") {
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("a", node.input(0));
found++;
}
}
- EXPECT_EQ(3, found);
+ EXPECT_EQ(4, found);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());