diff options
author | Benoit Steiner <bsteiner@google.com> | 2018-04-04 16:17:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-04 16:22:19 -0700 |
commit | f8acfb01792886274778d9ad7a9d990cbef14141 (patch) | |
tree | e089cae1d1813458fad1de0eff700d0d5ff57221 | |
parent | e98c13c55e519cb70ede110cd8941f8cb75ab718 (diff) |
Fixed handling of control dependencies in the arithmethic optimizer
PiperOrigin-RevId: 191665098
5 files changed, 64 insertions, 89 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 919f23fd98..59a5695af0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h" #include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -290,21 +289,16 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all // optimizations will be migrated to stages - void AddFrameControlDeps(const NodeDef* old_node, - const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep) { - const auto frame_it = ctx_.frame_map->find(old_node); - if (frame_it != ctx_.frame_map->end()) { - for (auto node : new_nodes) { - ctx_.frame_map->emplace(node, frame_it->second); - } - if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { - const string ctrl_dep = ConstantFolding::AddControlDependency( - source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map); - for (auto node : sinks_for_control_dep) { - MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph, - ctx_.node_map); + void ForwardControlDependencies( + NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) { + for (const auto& src : src_nodes) { + for (int i = src->input_size() - 1; i >= 0; --i) { + if (IsControlInput(src->input(i))) { + *target_node->add_input() = src->input(i); + ctx_.node_map->AddOutput(NodeName(src->input(i)), + target_node->name()); + } else { + break; } } } @@ -703,7 +697,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { CHECK(IsSupported(node)); std::set<string> common_factors; - TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors)); + std::vector<string> ctrl_deps; + TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps)); if (common_factors.size() == 1) { const string& common_factor = *common_factors.begin(); @@ -735,9 +730,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { new_add_node->set_input(i, unique_factors[i]); } - // Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, - {new_add_node}); + // Add control deps on add node + for (const string& ctrl_dep : ctrl_deps) { + *new_add_node->add_input() = ctrl_dep; + ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name()); + } // optimize new inner aggregation node AddToOptimizationQueue(new_add_node); @@ -763,14 +760,16 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { } // Determine the set of common factors if the input nodes are all Mul nodes. - Status GetCommonFactors(const NodeDef* node, - std::set<string>* common_factors) const { + Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors, + std::vector<string>* ctrl_deps) const { CHECK(common_factors->empty()); for (int i = 0; i < node->input_size(); ++i) { if (i > 0 && common_factors->empty()) break; - if (IsControlInput(node->input(i))) break; - + if (IsControlInput(node->input(i))) { + ctrl_deps->push_back(node->input(i)); + continue; + } NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input)); @@ -790,6 +789,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { std::inserter(intersection, intersection.begin())); std::swap(*common_factors, intersection); } + for (int i = 2; i < input->input_size(); ++i) { + ctrl_deps->push_back(input->input(i)); + } } return Status::OK(); } @@ -1275,20 +1277,15 @@ void ArithmeticOptimizer::DedupComputations() { } } -void ArithmeticOptimizer::AddFrameControlDeps( - const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep) { - const auto frame_it = frame_map_.find(old_node); - if (frame_it != frame_map_.end()) { - for (auto node : new_nodes) { - frame_map_.emplace(node, frame_it->second); - } - if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { - const string ctrl_dep = ConstantFolding::AddControlDependency( - source_for_ctrl_dep, optimized_graph_, node_map_.get()); - for (auto node : sinks_for_control_dep) { - MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get()); +void ArithmeticOptimizer::ForwardControlDependencies( + NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) { + for (const auto& src : src_nodes) { + for (int i = src->input_size() - 1; i >= 0; --i) { + if (IsControlInput(src->input(i))) { + *target_node->add_input() = src->input(i); + node_map_->AddOutput(NodeName(src->input(i)), target_node->name()); + } else { + break; } } } @@ -1408,10 +1405,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map_->AddOutput(new_transpose->name(), new_cast->name()); nodes_to_simplify->PushBack(new_transpose); - // Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_transpose, new_cast}, - new_transpose->input(0), {new_transpose}); - + ForwardControlDependencies(new_transpose, {cast, node}); return new_cast->name(); } } @@ -1485,7 +1479,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map_->AddOutput(weights->name(), scaled_weights->name()); scaled_weights->add_input(mul->input(1)); node_map_->AddOutput(scale->name(), scaled_weights->name()); - AddFrameControlDeps(node, {scaled_weights}, "", {}); + ForwardControlDependencies(scaled_weights, {source}); // Update `conv`'s weights to `scaled_weights`. conv->set_input(1, scaled_weights->name()); @@ -1521,7 +1515,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { - // Discard aggregate nodes with a single input. + // Discard aggregate nodes with a single input and no control dependencies. if (node->input_size() == 1) { return node->input(0); } @@ -1567,6 +1561,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( return ""; } new_const_node->set_device(node->device()); + MaybeAddControlInput(NodeName(node->input(0)), new_const_node, + optimized_graph_, node_map_.get()); nodes_to_simplify->PushBack(new_const_node); // 2. Replace the aggregate node with Mul(Const(N), x). @@ -1579,9 +1575,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->add_input(node->input(0)); node_map_->AddOutput(node->input(0), new_mul_node->name()); - CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get()); - AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0), - {new_const_node}); + ForwardControlDependencies(new_mul_node, {node}); return new_mul_node->name(); } } @@ -1614,7 +1608,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( FlipBooleanAttr(attr_a, new_op); new_op->set_input(0, a->input(0)); node_map_->UpdateInput(new_op->name(), a->name(), a->input(0)); - AddFrameControlDeps(node, {new_op}, a->input(0), {new_op}); } if (b_is_foldable) { const string attr_b = @@ -1622,10 +1615,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( FlipBooleanAttr(attr_b, new_op); new_op->set_input(1, b->input(0)); node_map_->UpdateInput(new_op->name(), b->name(), b->input(0)); - if (!a_is_foldable) { - AddFrameControlDeps(node, {new_op}, b->input(0), {new_op}); - } } + std::vector<const NodeDef*> deps_to_forward({node}); + if (a_is_foldable) { + deps_to_forward.push_back(a); + } + if (b_is_foldable) { + deps_to_forward.push_back(b); + } + ForwardControlDependencies(new_op, deps_to_forward); } } @@ -1647,7 +1645,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( : "Transpose"); new_op->set_input(0, input->input(0)); node_map_->UpdateInput(new_op->name(), node->name(), input->input(0)); - AddFrameControlDeps(node, {new_op}, "", {}); + ForwardControlDependencies(new_op, {node, input}); return new_op->name(); } } @@ -1663,8 +1661,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { } const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, - graph_properties_.get(), node_map_.get(), - &frame_map_); + graph_properties_.get(), node_map_.get()); const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify); // Stop pipeline after first stage returning non-empty simplified tensor name. @@ -1764,11 +1761,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, graph_properties_.reset(new GraphProperties(item)); TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); - // Identify loop frames - int num_frames; - TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, - &frame_map_, &num_frames)); - // Perform the optimizations. TF_RETURN_IF_ERROR(SimplifyArithmeticOps()); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 63a7b55893..7e81ed0a1f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -100,13 +99,9 @@ class ArithmeticOptimizer : public GraphOptimizer { // Dedup redundant nodes in the graph. void DedupComputations(); - // Fix frame dependencies by adding control dependencies from old_input to - // nodes in new_nodes_for_control_dep, and update frame_map for all nodes in - // new_nodes. - void AddFrameControlDeps(const NodeDef* old_node, - const std::vector<NodeDef*>& new_nodes, - const string& source_for_ctrl_dep, - const std::vector<NodeDef*>& sinks_for_control_dep); + // Forward the control dependencies anchored on src_nodes to the target_nodes. + void ForwardControlDependencies(NodeDef* target_node, + const std::vector<const NodeDef*>& src_nodes); // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. @@ -135,7 +130,6 @@ class ArithmeticOptimizer : public GraphOptimizer { bool fetch_nodes_known_ = false; std::unordered_set<string> nodes_to_preserve_; std::unique_ptr<NodeMap> node_map_; - FrameMap frame_map_; std::unique_ptr<GraphProperties> graph_properties_; GraphDef* optimized_graph_ = nullptr; // Not owned. }; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 48f1dd5aa1..e117341ba3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -520,26 +520,23 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6")); ASSERT_NE(add_6_node, nullptr); - EXPECT_EQ(3, add_6_node->input_size()); + EXPECT_EQ(2, add_6_node->input_size()); EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0)); EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1)); - EXPECT_EQ("^Placeholder", add_6_node->input(2)); const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4")); ASSERT_NE(add_4_node, nullptr); EXPECT_EQ("Add", add_4_node->op()); - EXPECT_EQ(3, add_4_node->input_size()); + EXPECT_EQ(2, add_4_node->input_size()); EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0)); EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1)); - EXPECT_EQ("^Placeholder", add_4_node->input(2)); const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5")); ASSERT_NE(add_5_node, nullptr); EXPECT_EQ("Add", add_5_node->op()); - EXPECT_EQ(3, add_5_node->input_size()); + EXPECT_EQ(2, add_5_node->input_size()); EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0)); EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1)); - EXPECT_EQ("^Placeholder", add_5_node->input(2)); const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const")); ASSERT_NE(add_const_node, nullptr); diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index 8d3e965c57..7ed0474861 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/grappler/utils/frame.h" namespace tensorflow { namespace grappler { @@ -45,21 +44,16 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name); struct GraphOptimizerContext { GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve, GraphDef* optimized_graph, - GraphProperties* graph_properties, NodeMap* node_map, - FrameMap* frame_map) + GraphProperties* graph_properties, NodeMap* node_map) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), graph_properties(graph_properties), - node_map(node_map), - frame_map(frame_map) {} + node_map(node_map) {} const std::unordered_set<string>* nodes_to_preserve; GraphDef* optimized_graph; GraphProperties* graph_properties; NodeMap* node_map; - // TODO(ezhulenev): it seems that frame_map is only relevant for loop - // optimizer? Move it to loop-optimizer specific context extension. - FrameMap* frame_map; }; Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc index 416327e622..3f5ab87a5a 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc @@ -58,8 +58,8 @@ TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScope_InScope) { TEST_F(GraphOptimizerStageTest, OptimizedNodeName) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ nullptr, - /*graph_properties*/ nullptr, /*node_name*/ nullptr, - /*frame_map*/ nullptr); + /*graph_properties*/ nullptr, + /*node_name*/ nullptr); FakeOptimizerStage stage("my_opt", "my_stg", ctx); const auto node = ParseNodeScopeAndName("a/b/c/Add"); @@ -94,8 +94,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map, - /*frame_map*/ nullptr); + /*node_name*/ &node_map); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; @@ -134,8 +133,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map, - /*frame_map*/ nullptr); + /*node_name*/ &node_map); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; @@ -165,4 +163,4 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { } // namespace } // end namespace grappler -} // end namespace tensorflow
\ No newline at end of file +} // end namespace tensorflow |