diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-14 01:22:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-14 01:25:28 -0700 |
commit | 945efa4222a66977c03638086773c369c16d5c61 (patch) | |
tree | 10b5307041377cd6ce36a359f7ee3661b21513c8 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | |
parent | 6a581e1d7c28f5b8f487f2a91649d7e2866974f4 (diff) |
Make sure that same nodes are not optimized as part of multiple groups.
Replace recusrsion with iteration in AbsorbInputByOptimizedNodesGroup.
PiperOrigin-RevId: 192874364
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index e639812858..cb1f2ea732 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -105,6 +105,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_identity_transpose = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; + options.remove_negation = false; optimizer->options_ = options; } @@ -2069,20 +2070,20 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) { // a b c D a b NodeMap node_map(&output); - const NodeDef* mul1_node = node_map.GetNode("mul1"); + const NodeDef* mul1_node = node_map.GetNode("mul2"); ASSERT_NE(mul1_node, nullptr); EXPECT_EQ("a", mul1_node->input(0)); EXPECT_EQ("b", mul1_node->input(1)); - const NodeDef* mul2_node = node_map.GetNode("mul2"); + const NodeDef* mul2_node = node_map.GetNode("mul1"); ASSERT_NE(mul2_node, nullptr); - EXPECT_EQ("mul1", mul2_node->input(0)); + EXPECT_EQ("mul2", mul2_node->input(0)); EXPECT_EQ("c", mul2_node->input(1)); const NodeDef* mul3_node = node_map.GetNode("mul3"); ASSERT_NE(mul3_node, nullptr); EXPECT_EQ("D", mul3_node->input(0)); - EXPECT_EQ("mul2", mul3_node->input(1)); + EXPECT_EQ("mul1", mul3_node->input(1)); } } // namespace grappler |