diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/constant_folding_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 31 |
1 files changed, 20 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 32a691d3ee..a17ec733ea 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -81,26 +81,27 @@ TEST_F(ConstantFoldingTest, NeutralElement) { for (bool use_const : {true, false}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, - ops::Placeholder::Shape(TensorShape({1, 2}))); + ops::Placeholder::Shape(TensorShape({2, 2}))); Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, - ops::Placeholder::Shape(TensorShape({1, 2}))); - Output zeros = - !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) - : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2}); + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) + : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2}); Output zeros_broadcast = - ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1}); - Output ones = !use_const - ? ops::OnesLike(s.WithOpName("ones"), x) - : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2}); + ops::Const(s.WithOpName("zeros_broadcast"), 0.0f, {1, 1}); + Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x) + : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2}); Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y); Output mul5 = ops::Mul(s.WithOpName("mul1"), x, zeros_broadcast); Output mul6 = ops::Mul(s.WithOpName("mul2"), zeros_broadcast, y); + Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros); + Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y); Output add1 = ops::Add(s.WithOpName("add1"), x, zeros); Output add2 = ops::Add(s.WithOpName("add2"), zeros, y); - Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2}); + Output addn = + ops::AddN(s, {mul1, mul2, mul3, mul4, matmul1, matmul2, add1, add2}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -110,7 +111,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(14, output.node_size()); + EXPECT_EQ(16, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); const string& name = node.name(); @@ -132,6 +133,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("zeros", node.input(0)); EXPECT_EQ("^y", node.input(1)); } + } else if (name == "matmul1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } else if (name == "matmul2") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^y", node.input(1)); } else if (name == "mul3") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("x", node.input(0)); |