author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-12-01 16:17:04 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-12-01 16:20:38 -0800
commit     59ae88dac4399a8719aebe1b90f87f61fd1fd7e5
tree       47e271336cb3726745308e4054ae371cd3efddca
parent     2270ecc169025c7fa33edd09700abcd72a777373
Eliminate matrix multiplication with zeros.
PiperOrigin-RevId: 177655417
-rw-r--r--  tensorflow/core/grappler/op_types.cc                            6
-rw-r--r--  tensorflow/core/grappler/op_types.h                             1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc        30
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc   31
4 files changed, 48 insertions(+), 20 deletions(-)
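When shape inference proves that one operand of a MatMul is all zeros, the product is itself all zeros, so the whole multiplication can be replaced by an Identity of the zero tensor, with the other operand demoted to a control dependency. The sketch below builds a graph that qualifies for the new rewrite; it is modeled on the NeutralElement test in this change, and the function name and standalone setup are illustrative rather than part of the commit.

    // Illustrative only: "matmul1" is provably all zeros, so the AGGRESSIVE-level
    // constant folder can rewrite it into an Identity of "zeros" with a control
    // dependency on "x".
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/graph.pb.h"

    void BuildZeroMatMulExample(tensorflow::GraphDef* graph_def) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      auto x = tensorflow::ops::Placeholder(
          s.WithOpName("x"), tensorflow::DT_FLOAT,
          tensorflow::ops::Placeholder::Shape(tensorflow::TensorShape({2, 2})));
      auto zeros = tensorflow::ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
      // x (2x2) times a 2x2 zero matrix is a 2x2 zero matrix for any x.
      auto matmul1 = tensorflow::ops::MatMul(s.WithOpName("matmul1"), x, zeros);
      TF_CHECK_OK(s.ToGraphDef(graph_def));
    }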
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 15fcaa857e..571975aca1 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -90,6 +90,12 @@ bool IsIdentity(const NodeDef& node) {
   return op == "Identity" || op == "RefIdentity";
 }
 
+bool IsMatMul(const NodeDef& node) {
+  const auto op = node.op();
+  return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
+         op == "SparseMatMul";
+}
+
 bool IsMerge(const NodeDef& node) {
   const auto op = node.op();
   return op == "Merge" || op == "RefMerge";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index b1d81448af..47dd2c7faf 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -43,6 +43,7 @@ bool IsFusedBatchNormGradV1(const NodeDef& node);
 bool IsIdentity(const NodeDef& node);
 bool IsMerge(const NodeDef& node);
 bool IsMul(const NodeDef& node);
+bool IsMatMul(const NodeDef& node);
 bool IsNextIteration(const NodeDef& node);
 bool IsPad(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index e0f39c2931..84f3cc9df7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1317,9 +1317,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
 
     // Simplify multiplication by ones or zeros, and addition of zeros.
     bool is_mul = IsMul(node);
+    bool is_matmul = IsMatMul(node);
     bool is_add = IsAdd(node);
     if (opt_level_ == RewriterConfig::AGGRESSIVE && use_shape_info &&
-        (is_mul || is_add) && properties.HasInputProperties(node.name()) &&
+        (is_mul || is_matmul || is_add) &&
+        properties.HasInputProperties(node.name()) &&
         properties.HasOutputProperties(node.name())) {
       const NodeDef* x = node_map_->GetNode(node.input(0));
       const NodeDef* y = node_map_->GetNode(node.input(1));
@@ -1335,24 +1337,34 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
 
       // Simplify multiplication by or addition of zeros.
       const bool x_is_zero = IsZeros(*x);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
-      if (x_is_zero && x_matches_output_shape) {
-        // 0 * y = 0 or 0 + y = y.
-        ReplaceAddOrMulWithIdentity(is_mul ? 0 : 1, &node);
+      if (x_is_zero) {
+        if ((is_mul && x_matches_output_shape) || is_matmul) {
+          // 0 * y = 0
+          ReplaceAddOrMulWithIdentity(0, &node);
+        } else {
+          // 0 + y = y.
+          ReplaceAddOrMulWithIdentity(1, &node);
+        }
         continue;
       }
       const TensorShapeProto& y_shape =
           properties.GetInputProperties(node.name())[1].shape();
       const bool y_is_zero = IsZeros(*y);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
-      if (y_is_zero && y_matches_output_shape) {
-        // x * 0 = 0 or x + 0 = x.
-        ReplaceAddOrMulWithIdentity(is_mul ? 1 : 0, &node);
+      if (y_is_zero) {
+        if ((is_mul && y_matches_output_shape) || is_matmul) {
+          // x * 0 = 0
+          ReplaceAddOrMulWithIdentity(1, &node);
+        } else {
+          // x + 0 = x.
+          ReplaceAddOrMulWithIdentity(0, &node);
+        }
         continue;
       }
       if (is_mul) {
-        // Simplify multiplication by zeros where the output shape does not
-        // match the shape of the zero input.
+        // Simplify scalar multiplication by zeros where, due to broadcasting,
+        // the output shape does not match the shape of the zero input.
         if (x_is_zero || y_is_zero) {
           TF_RETURN_IF_ERROR(
               ReplaceAddOrMulWithConstant(0, output_shape, &node));
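The test expectations added below pin down the rewritten form: the MatMul node becomes an Identity whose data input is the zero tensor, while the remaining operand is kept only as a control input so its producer still runs. A simplified sketch of that node-level rewrite follows; it uses a hypothetical helper and is not the actual ReplaceAddOrMulWithIdentity implementation (the real helper also updates the optimizer's bookkeeping such as its node map).

    // Hypothetical helper, not the real ReplaceAddOrMulWithIdentity: rewrite a
    // two-input node whose zero_input_index-th operand is all zeros into an
    // Identity of that operand, keeping the other operand as a control
    // dependency. Assumes inputs are plain node names without port suffixes.
    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"

    void RewriteZeroMatMulToIdentity(int zero_input_index,
                                     tensorflow::NodeDef* node) {
      const std::string zero_input = node->input(zero_input_index);
      const std::string other_input = node->input(1 - zero_input_index);
      node->set_op("Identity");
      node->clear_input();
      node->add_input(zero_input);         // forwarded data input: the zeros tensor
      node->add_input("^" + other_input);  // control dependency on the other operand
    }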
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 32a691d3ee..a17ec733ea 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -81,26 +81,27 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
   for (bool use_const : {true, false}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
-                                ops::Placeholder::Shape(TensorShape({1, 2})));
+                                ops::Placeholder::Shape(TensorShape({2, 2})));
     Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
-                                ops::Placeholder::Shape(TensorShape({1, 2})));
-    Output zeros =
-        !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
-                   : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2});
+                                ops::Placeholder::Shape(TensorShape({2, 2})));
+    Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
+                              : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
     Output zeros_broadcast =
-        ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1});
-    Output ones = !use_const
-                      ? ops::OnesLike(s.WithOpName("ones"), x)
-                      : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2});
+        ops::Const(s.WithOpName("zeros_broadcast"), 0.0f, {1, 1});
+    Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
+                             : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
     Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
     Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
     Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
     Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
     Output mul5 = ops::Mul(s.WithOpName("mul1"), x, zeros_broadcast);
     Output mul6 = ops::Mul(s.WithOpName("mul2"), zeros_broadcast, y);
+    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
+    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
     Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
     Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
-    Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2});
+    Output addn =
+        ops::AddN(s, {mul1, mul2, mul3, mul4, matmul1, matmul2, add1, add2});
 
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -110,7 +111,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
 
-    EXPECT_EQ(14, output.node_size());
+    EXPECT_EQ(16, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -132,6 +133,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
           EXPECT_EQ("zeros", node.input(0));
           EXPECT_EQ("^y", node.input(1));
         }
+      } else if (name == "matmul1") {
+        EXPECT_EQ("Identity", node.op());
+        EXPECT_EQ("zeros", node.input(0));
+        EXPECT_EQ("^x", node.input(1));
+      } else if (name == "matmul2") {
+        EXPECT_EQ("Identity", node.op());
+        EXPECT_EQ("zeros", node.input(0));
+        EXPECT_EQ("^y", node.input(1));
       } else if (name == "mul3") {
         EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));