aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-14 15:58:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 16:02:36 -0800
commit481b5f4410b34b65570f9dce62b34e9199769a38 (patch)
treef3b235b5cb182e5113b9148857a42f21dccc111d
parent264e7e8b4b28a84a94310e20fa26d8e8e2a9cd60 (diff)
Enable associative & commutative operator optimization.
PiperOrigin-RevId: 179111549
-rw-r--r--tensorflow/core/grappler/op_types.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc3
3 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 75a11a4d36..24c372a7cf 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -26,7 +26,11 @@ namespace tensorflow {
namespace grappler {
bool IsAdd(const NodeDef& node) {
- return node.op() == "Add" || node.op() == "AddV2";
+ if (node.op() == "AddV2" || node.op() == "Add") {
+ DataType type = node.attr().at("T").type();
+ return type != DT_STRING;
+ }
+ return false;
}
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 360ada4b1c..59df49c245 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1486,8 +1486,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// TODO(rmlarsen): Handle non-associative/non-commutative operators like
// subtraction and division, as well as mixed subtraction/addition,
// division/multiplication.
- if (is_aggressive && (is_add || is_mul) &&
- NumNonControlInputs(*node) == 2) {
+ if ((is_add || is_mul) && NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
@@ -1512,7 +1511,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
continue;
}
- const int parent_const_input = left_child_is_constant ? 0 : 1;
+ // Identify the nodes to swap.
const NodeDef* left_leaf = node_map_->GetNode(child_node->input(0));
const NodeDef* right_leaf = node_map_->GetNode(child_node->input(1));
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
@@ -1521,7 +1520,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// Child is already foldable, leave it alone.
continue;
}
- int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
+ const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
+ const int parent_const_input = left_child_is_constant ? 0 : 1;
// Swap the constant child with a non-constant leaf node.
node_map_->UpdateInput(node->name(), node->input(parent_const_input),
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 31e52c7a4e..a3b3e522eb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -97,11 +97,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */);
+ ConstantFolding fold(nullptr /* cpu_device */);
GraphDef output;
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- LOG(INFO) << "Final results =\n" << output.DebugString();
EXPECT_EQ(9, output.node_size());