From 59ae88dac4399a8719aebe1b90f87f61fd1fd7e5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Dec 2017 16:17:04 -0800 Subject: Eliminate matrix multiplication with zeros. PiperOrigin-RevId: 177655417 --- tensorflow/core/grappler/op_types.cc | 6 +++++ tensorflow/core/grappler/op_types.h | 1 + .../core/grappler/optimizers/constant_folding.cc | 30 ++++++++++++++------- .../grappler/optimizers/constant_folding_test.cc | 31 ++++++++++++++-------- 4 files changed, 48 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 15fcaa857e..571975aca1 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -90,6 +90,12 @@ bool IsIdentity(const NodeDef& node) { return op == "Identity" || op == "RefIdentity"; } +bool IsMatMul(const NodeDef& node) { + const auto op = node.op(); + return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" || + op == "SparseMatMul"; +} + bool IsMerge(const NodeDef& node) { const auto op = node.op(); return op == "Merge" || op == "RefMerge"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index b1d81448af..47dd2c7faf 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -43,6 +43,7 @@ bool IsFusedBatchNormGradV1(const NodeDef& node); bool IsIdentity(const NodeDef& node); bool IsMerge(const NodeDef& node); bool IsMul(const NodeDef& node); +bool IsMatMul(const NodeDef& node); bool IsNextIteration(const NodeDef& node); bool IsPad(const NodeDef& node); bool IsNoOp(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index e0f39c2931..84f3cc9df7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1317,9 +1317,11 @@ Status 
ConstantFolding::SimplifyGraph(GraphDef* output, // Simplify multiplication by ones or zeros, and addition of zeros. bool is_mul = IsMul(node); + bool is_matmul = IsMatMul(node); bool is_add = IsAdd(node); if (opt_level_ == RewriterConfig::AGGRESSIVE && use_shape_info && - (is_mul || is_add) && properties.HasInputProperties(node.name()) && + (is_mul || is_matmul || is_add) && + properties.HasInputProperties(node.name()) && properties.HasOutputProperties(node.name())) { const NodeDef* x = node_map_->GetNode(node.input(0)); const NodeDef* y = node_map_->GetNode(node.input(1)); @@ -1335,24 +1337,34 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // Simplify multiplication by or addition of zeros. const bool x_is_zero = IsZeros(*x); const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); - if (x_is_zero && x_matches_output_shape) { - // 0 * y = 0 or 0 + y = y. - ReplaceAddOrMulWithIdentity(is_mul ? 0 : 1, &node); + if (x_is_zero) { + if ((is_mul && x_matches_output_shape) || is_matmul) { + // 0 * y = 0 + ReplaceAddOrMulWithIdentity(0, &node); + } else { + // 0 + y = y. + ReplaceAddOrMulWithIdentity(1, &node); + } continue; } const TensorShapeProto& y_shape = properties.GetInputProperties(node.name())[1].shape(); const bool y_is_zero = IsZeros(*y); const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); - if (y_is_zero && y_matches_output_shape) { - // x * 0 = 0 or x + 0 = x. - ReplaceAddOrMulWithIdentity(is_mul ? 1 : 0, &node); + if (y_is_zero) { + if ((is_mul && y_matches_output_shape) || is_matmul) { + // x * 0 = 0 + ReplaceAddOrMulWithIdentity(1, &node); + } else { + // x + 0 = x. + ReplaceAddOrMulWithIdentity(0, &node); + } continue; } if (is_mul) { - // Simplify multiplication by zeros where the output shape does not - // match the shape of the zero input. + // Simplify scalar multiplication by zeros where, due to broadcasting, + // the output shape does not match the shape of the zero input. 
if (x_is_zero || y_is_zero) { TF_RETURN_IF_ERROR( ReplaceAddOrMulWithConstant(0, output_shape, &node)); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 32a691d3ee..a17ec733ea 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -81,26 +81,27 @@ TEST_F(ConstantFoldingTest, NeutralElement) { for (bool use_const : {true, false}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, - ops::Placeholder::Shape(TensorShape({1, 2}))); + ops::Placeholder::Shape(TensorShape({2, 2}))); Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, - ops::Placeholder::Shape(TensorShape({1, 2}))); - Output zeros = - !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) - : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2}); + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) + : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2}); Output zeros_broadcast = - ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1}); - Output ones = !use_const - ? ops::OnesLike(s.WithOpName("ones"), x) - : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2}); + ops::Const(s.WithOpName("zeros_broadcast"), 0.0f, {1, 1}); + Output ones = !use_const ? 
ops::OnesLike(s.WithOpName("ones"), x) + : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2}); Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y); Output mul5 = ops::Mul(s.WithOpName("mul1"), x, zeros_broadcast); Output mul6 = ops::Mul(s.WithOpName("mul2"), zeros_broadcast, y); + Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros); + Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y); Output add1 = ops::Add(s.WithOpName("add1"), x, zeros); Output add2 = ops::Add(s.WithOpName("add2"), zeros, y); - Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2}); + Output addn = + ops::AddN(s, {mul1, mul2, mul3, mul4, matmul1, matmul2, add1, add2}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -110,7 +111,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(14, output.node_size()); + EXPECT_EQ(16, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); const string& name = node.name(); @@ -132,6 +133,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("zeros", node.input(0)); EXPECT_EQ("^y", node.input(1)); } + } else if (name == "matmul1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } else if (name == "matmul2") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ("^y", node.input(1)); } else if (name == "mul3") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("x", node.input(0)); -- cgit v1.2.3