aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/constant_folding_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/constant_folding_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc31
1 files changed, 20 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 32a691d3ee..a17ec733ea 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -81,26 +81,27 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
for (bool use_const : {true, false}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
- ops::Placeholder::Shape(TensorShape({1, 2})));
+ ops::Placeholder::Shape(TensorShape({2, 2})));
Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
- ops::Placeholder::Shape(TensorShape({1, 2})));
- Output zeros =
- !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
- : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2});
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
+ : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
Output zeros_broadcast =
- ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1});
- Output ones = !use_const
- ? ops::OnesLike(s.WithOpName("ones"), x)
- : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2});
+ ops::Const(s.WithOpName("zeros_broadcast"), 0.0f, {1, 1});
+ Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
+ : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
Output mul5 = ops::Mul(s.WithOpName("mul1"), x, zeros_broadcast);
Output mul6 = ops::Mul(s.WithOpName("mul2"), zeros_broadcast, y);
+ Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
+ Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
- Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2});
+ Output addn =
+ ops::AddN(s, {mul1, mul2, mul3, mul4, matmul1, matmul2, add1, add2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -110,7 +111,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(14, output.node_size());
+ EXPECT_EQ(16, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
@@ -132,6 +133,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("zeros", node.input(0));
EXPECT_EQ("^y", node.input(1));
}
+ } else if (name == "matmul1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^x", node.input(1));
+ } else if (name == "matmul2") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));