aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-03 13:00:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:40:29 -0700
commitceda30408f66a7eea86dc359164deb662d5a32d0 (patch)
treea20d71c9d126dca85b7e1588d8f661c13f3a1b6d /tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
parent775d1c03c1772c0c2e10e5884af8d9363cfdf314 (diff)
Enable unary chain hoisting optimization for concat/split/splitv by default.
PiperOrigin-RevId: 195297330
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index f903f53a35..d32743f3f2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2320,16 +2320,16 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
EXPECT_NE(node.name(), "cos_exp_b2");
if (node.name() == "split1") {
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("axis", node.input(0));
EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
- EXPECT_EQ("^ctrl1", node.input(2));
found++;
}
if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
EXPECT_EQ("Sin", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "id_a") {
@@ -2349,8 +2349,11 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
}
if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
EXPECT_EQ("Exp", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(4, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
+ EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("^ctrl3", node.input(3));
found++;
}
if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
@@ -2360,13 +2363,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
found++;
}
if (node.name() == "split2") {
- EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
EXPECT_EQ("size_splits2", node.input(1));
EXPECT_EQ("axis", node.input(2));
- EXPECT_EQ("^ctrl1", node.input(3));
- EXPECT_EQ("^ctrl2", node.input(4));
- EXPECT_EQ("^ctrl3", node.input(5));
found++;
}
if (node.name() == "id_a2") {