aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-14 01:22:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-14 01:25:28 -0700
commit945efa4222a66977c03638086773c369c16d5c61 (patch)
tree10b5307041377cd6ce36a359f7ee3661b21513c8 /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent6a581e1d7c28f5b8f487f2a91649d7e2866974f4 (diff)
Make sure that same nodes are not optimized as part of multiple groups.
Replace recusrsion with iteration in AbsorbInputByOptimizedNodesGroup. PiperOrigin-RevId: 192874364
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e639812858..cb1f2ea732 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -105,6 +105,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_identity_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
+ options.remove_negation = false;
optimizer->options_ = options;
}
@@ -2069,20 +2070,20 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
// a b c D a b
NodeMap node_map(&output);
- const NodeDef* mul1_node = node_map.GetNode("mul1");
+ const NodeDef* mul1_node = node_map.GetNode("mul2");
ASSERT_NE(mul1_node, nullptr);
EXPECT_EQ("a", mul1_node->input(0));
EXPECT_EQ("b", mul1_node->input(1));
- const NodeDef* mul2_node = node_map.GetNode("mul2");
+ const NodeDef* mul2_node = node_map.GetNode("mul1");
ASSERT_NE(mul2_node, nullptr);
- EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("mul2", mul2_node->input(0));
EXPECT_EQ("c", mul2_node->input(1));
const NodeDef* mul3_node = node_map.GetNode("mul3");
ASSERT_NE(mul3_node, nullptr);
EXPECT_EQ("D", mul3_node->input(0));
- EXPECT_EQ("mul2", mul3_node->input(1));
+ EXPECT_EQ("mul1", mul3_node->input(1));
}
} // namespace grappler