diff options
author | 2017-12-14 15:58:25 -0800 | |
---|---|---|
committer | 2017-12-14 16:02:36 -0800 | |
commit | 481b5f4410b34b65570f9dce62b34e9199769a38 (patch) | |
tree | f3b235b5cb182e5113b9148857a42f21dccc111d | |
parent | 264e7e8b4b28a84a94310e20fa26d8e8e2a9cd60 (diff) |
Enable associative & commutative operator optimization.
PiperOrigin-RevId: 179111549
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 3 |
3 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 75a11a4d36..24c372a7cf 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -26,7 +26,11 @@ namespace tensorflow { namespace grappler { bool IsAdd(const NodeDef& node) { - return node.op() == "Add" || node.op() == "AddV2"; + if (node.op() == "AddV2" || node.op() == "Add") { + DataType type = node.attr().at("T").type(); + return type != DT_STRING; + } + return false; } bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 360ada4b1c..59df49c245 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1486,8 +1486,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // TODO(rmlarsen): Handle non-associative/non-commutative operators like // subtraction and division, as well as mixed subtraction/addition, // division/multiplication. - if (is_aggressive && (is_add || is_mul) && - NumNonControlInputs(*node) == 2) { + if ((is_add || is_mul) && NumNonControlInputs(*node) == 2) { NodeDef* left_child = node_map_->GetNode(node->input(0)); NodeDef* right_child = node_map_->GetNode(node->input(1)); // One child must be constant, and the other the same op as the parent. @@ -1512,7 +1511,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, continue; } - const int parent_const_input = left_child_is_constant ? 0 : 1; + // Identify the nodes to swap. const NodeDef* left_leaf = node_map_->GetNode(child_node->input(0)); const NodeDef* right_leaf = node_map_->GetNode(child_node->input(1)); const bool left_leaf_is_constant = IsReallyConstant(*left_leaf); @@ -1521,7 +1520,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // Child is already foldable, leave it alone. continue; } - int non_const_leaf_input = left_leaf_is_constant ? 1 : 0; + const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0; + const int parent_const_input = left_child_is_constant ? 0 : 1; // Swap the constant child with a non-constant leaf node. node_map_->UpdateInput(node->name(), node->input(parent_const_input), diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 31e52c7a4e..a3b3e522eb 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -97,11 +97,10 @@ TEST_F(ConstantFoldingTest, AddTree) { item.fetch = {"add_parent", "mul_parent", "addmul_parent"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + ConstantFolding fold(nullptr /* cpu_device */); GraphDef output; Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - LOG(INFO) << "Final results =\n" << output.DebugString(); EXPECT_EQ(9, output.node_size()); |