diff options
author | Jingyue Wu <jingyue@google.com> | 2018-06-19 09:42:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 09:44:54 -0700 |
commit | a14de341d069387ff8c8a98ff73bf1e5782a5cae (patch) | |
tree | 95a0903eed8f1c4a7788bed776eeb36c596c5bb6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 316fee40d4978db2f6abbb5ff35cf8d979bee93e (diff) |
Automated g4 rollback of changelist 201069367
PiperOrigin-RevId: 201190626
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d0e6b04679..e1d55cdf5f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2976,8 +2976,12 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); - Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a); - Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1); + Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {}); + Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {}); + Output sn1 = + ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a); + Output sn2 = + ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1); Output out1 = ops::Identity(s.WithOpName("out1"), sn2); Output id1 = ops::Identity(s.WithOpName("id1"), a); Output id2 = ops::Identity(s.WithOpName("id2"), id1); @@ -2993,24 +2997,32 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { EnableOnlyRemoveIdempotent(&optimizer); OptimizeTwice(&optimizer, &item, &output); - EXPECT_EQ(7, output.node_size()); + EXPECT_EQ(11, output.node_size()); int found = 0; for (const NodeDef& node : output.node()) { if (node.name() == "out1") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("sn1", node.input(0)); + EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0)); + found++; + } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Snapshot", node.op()); + EXPECT_EQ("a", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); + EXPECT_EQ("^ctrl2", node.input(2)); found++; } else if (node.name() == "out2") { EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("id1", node.input(0)); + EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0)); found++; - } else if (node.name() == "sn1") { + } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") { + EXPECT_EQ("Identity", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("a", node.input(0)); found++; } } - EXPECT_EQ(3, found); + EXPECT_EQ(4, found); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); |