author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-13 10:39:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-03-13 10:43:41 -0700
commit     ea9e65c94ad71ca86d2be91c4109c62269b42cf8 (patch)
tree       426f2930ba290c9360c083f0fe64121715b1264c
parent     29c129c69d3aa602d0fa083f7477144466f562d1 (diff)
Enable arithmetic optimizations for Fill nodes that are all zeros or ones.
PiperOrigin-RevId: 188893722
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       | 12
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  | 58
2 files changed, 48 insertions(+), 22 deletions(-)
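
The change itself is small: IsZeros() and IsOnes() in the constant-folding pass learn to look through a Fill node by recursing into its value operand (input 1), so the existing x*1, x+0, x*0, etc. rewrites also fire when the zero/one tensor is produced by Fill. Below is a minimal, self-contained sketch of that recursive check; FakeNode and the flat map are simplified stand-ins, not the real NodeDef/NodeMap types.

// Minimal sketch (stand-in types, not the real TensorFlow classes) of the
// recursive all-zeros / all-ones detection added in this commit.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct FakeNode {
  std::string name;
  std::string op;                   // "Const", "Fill", "ZerosLike", "OnesLike", ...
  std::vector<std::string> inputs;  // names of input nodes
  float const_value = 0.0f;         // only meaningful when op == "Const"
};

using NodeMap = std::unordered_map<std::string, FakeNode>;

// Fill(dims, value) is all zeros/ones exactly when its value operand is.
bool IsZeros(const NodeMap& graph, const FakeNode& node) {
  if (node.op == "ZerosLike") return true;
  if (node.op == "Fill") {
    auto it = graph.find(node.inputs[1]);  // input 1 is the fill value
    return it != graph.end() && IsZeros(graph, it->second);
  }
  return node.op == "Const" && node.const_value == 0.0f;
}

bool IsOnes(const NodeMap& graph, const FakeNode& node) {
  if (node.op == "OnesLike") return true;
  if (node.op == "Fill") {
    auto it = graph.find(node.inputs[1]);
    return it != graph.end() && IsOnes(graph, it->second);
  }
  return node.op == "Const" && node.const_value == 1.0f;
}

int main() {
  NodeMap graph;
  graph["dims"] = {"dims", "Const", {}, 2.0f};
  graph["one"] = {"one", "Const", {}, 1.0f};
  graph["ones_fill"] = {"ones_fill", "Fill", {"dims", "one"}};
  std::cout << IsOnes(graph, graph["ones_fill"]) << "\n";  // prints 1
}
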
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index a4d8376667..21037ff794 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1377,6 +1377,10 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
if (node.op() == "OnesLike") {
return true;
}
+ if (node.op() == "Fill") {
+ NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+ return values != nullptr && IsOnes(*values);
+ }
if (node.op() != "Const") {
return false;
}
@@ -1408,6 +1412,10 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
if (node.op() == "ZerosLike") {
return true;
}
+ if (node.op() == "Fill") {
+ NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+ return values != nullptr && IsZeros(*values);
+ }
if (!IsConstant(node)) {
return false;
}
@@ -1846,7 +1854,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const TensorShapeProto& y_shape =
properties->GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
- const bool x_is_one = IsOnes(*x);
+ const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
@@ -1873,7 +1881,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const TensorShapeProto& x_shape =
properties->GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
- const bool y_is_one = IsOnes(*y);
+ const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
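
For context on the two one-line changes above: with Fill handled recursively, IsOnes() would otherwise re-walk the value operand even when the node was already classified as all zeros, so the check is skipped in that case (a node cannot be both). Once an operand is known to be all ones, the existing neutral-element rewrite replaces the multiply with a Snapshot of the other input and keeps the ones node as a control dependency, which is what the updated test below expects. Continuing the simplified sketch from above (same stand-in types; the output-shape match check is elided):

// Hedged sketch, not the real SimplifyGraph() logic: rewrite Mul(x, ones)
// into Snapshot(x) with a "^ones" control input, mirroring the test below.
FakeNode SimplifyMulByNeutralElement(const NodeMap& graph, const FakeNode& mul) {
  const FakeNode& y = graph.at(mul.inputs[1]);
  const bool y_is_zero = IsZeros(graph, y);
  // Guard added in this change: skip the recursive IsOnes() walk when the
  // operand is already known to be all zeros.
  const bool y_is_one = y_is_zero ? false : IsOnes(graph, y);
  if (!y_is_one) return mul;  // not a multiply by all-ones; leave unchanged
  FakeNode snapshot = mul;
  snapshot.op = "Snapshot";
  snapshot.inputs = {mul.inputs[0], "^" + y.name};  // data input + control dep
  return snapshot;
}
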
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 724fb84f3e..cf151d4c4b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -152,7 +152,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
}
TEST_F(ConstantFoldingTest, NeutralElement) {
- for (bool use_const : {true, false}) {
+ int kConst = 0;
+ int kLike = 1;
+ int kFill = 2;
+ for (int const_type : {kConst, kLike, kFill}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
@@ -164,11 +167,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
ops::Placeholder::Shape(TensorShape({2, 3})));
Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2})));
- Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
- : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
- Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
- : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
+ Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
+ Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
+ Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
+ Output zeros = const_type == kConst
+ ? zeros_const
+ : (const_type == kLike ? zeros_like : zeros_fill);
+ Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
+ Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
+ Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
+ Output ones = const_type == kConst
+ ? ones_const
+ : (const_type == kLike ? ones_like : ones_fill);
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
@@ -201,6 +212,13 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ const string suffix =
+ (const_type == kConst ? "_const"
+ : (const_type == kLike ? "_like" : "_fill"));
+ const string zeros_name = strings::StrCat("zeros", suffix);
+ const string ones_name = strings::StrCat("ones", suffix);
+ const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
+ const string ctrl_ones_name = strings::StrCat("^ones", suffix);
EXPECT_EQ(28, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
@@ -208,19 +226,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
if (name == "mul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "mul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul4") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul5") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
@@ -232,23 +250,23 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "div1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "div2") {
EXPECT_EQ("Reciprocal", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
+ EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "matmul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "matmul2") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "matmul3") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^a", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
EXPECT_EQ(0, t.float_val(0));
@@ -257,7 +275,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(2, t.tensor_shape().dim(1).size());
} else if (name == "matmul4") {
EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^zeros", node.input(0));
+ EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^b", node.input(1));
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(1, t.float_val_size());
@@ -268,11 +286,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add2") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "bias_add1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
@@ -280,16 +298,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
} else if (name == "bias_add2") {
// We don't eliminate this one, because it requires broadcasting.
EXPECT_EQ("BiasAdd", node.op());
- EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
} else if (name == "sub1") {
EXPECT_EQ("Snapshot", node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "sub2") {
EXPECT_EQ("Neg", node.op());
EXPECT_EQ("y", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
+ EXPECT_EQ(ctrl_zeros_name, node.input(1));
}
const std::set<string> square_zero_const{"mul1", "mul2", "mul5",
"mul6", "matmul1", "matmul2"};