about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Benoit Steiner <bsteiner@google.com> 2018-04-04 16:17:46 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-04 16:22:19 -0700
commit: f8acfb01792886274778d9ad7a9d990cbef14141 (patch)
tree: e089cae1d1813458fad1de0eff700d0d5ff57221
parent: e98c13c55e519cb70ede110cd8941f8cb75ab718 (diff)
Fixed handling of control dependencies in the arithmetic optimizer
PiperOrigin-RevId: 191665098
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc 110
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.h 12
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc 9
-rw-r--r-- tensorflow/core/grappler/optimizers/graph_optimizer_stage.h 10
-rw-r--r-- tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc 12
5 files changed, 64 insertions, 89 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 919f23fd98..59a5695af0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -290,21 +289,16 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
// TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
// optimizations will be migrated to stages
- void AddFrameControlDeps(const NodeDef* old_node,
- const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = ctx_.frame_map->find(old_node);
- if (frame_it != ctx_.frame_map->end()) {
- for (auto node : new_nodes) {
- ctx_.frame_map->emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
- ctx_.node_map);
+ void ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ ctx_.node_map->AddOutput(NodeName(src->input(i)),
+ target_node->name());
+ } else {
+ break;
}
}
}
@@ -703,7 +697,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
CHECK(IsSupported(node));
std::set<string> common_factors;
- TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+ std::vector<string> ctrl_deps;
+ TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
@@ -735,9 +730,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
new_add_node->set_input(i, unique_factors[i]);
}
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
+ // Add control deps on add node
+ for (const string& ctrl_dep : ctrl_deps) {
+ *new_add_node->add_input() = ctrl_dep;
+ ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
+ }
// optimize new inner aggregation node
AddToOptimizationQueue(new_add_node);
@@ -763,14 +760,16 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
}
// Determine the set of common factors if the input nodes are all Mul nodes.
- Status GetCommonFactors(const NodeDef* node,
- std::set<string>* common_factors) const {
+ Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+ std::vector<string>* ctrl_deps) const {
CHECK(common_factors->empty());
for (int i = 0; i < node->input_size(); ++i) {
if (i > 0 && common_factors->empty()) break;
- if (IsControlInput(node->input(i))) break;
-
+ if (IsControlInput(node->input(i))) {
+ ctrl_deps->push_back(node->input(i));
+ continue;
+ }
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
@@ -790,6 +789,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
std::inserter(intersection, intersection.begin()));
std::swap(*common_factors, intersection);
}
+ for (int i = 2; i < input->input_size(); ++i) {
+ ctrl_deps->push_back(input->input(i));
+ }
}
return Status::OK();
}
@@ -1275,20 +1277,15 @@ void ArithmeticOptimizer::DedupComputations() {
}
}
-void ArithmeticOptimizer::AddFrameControlDeps(
- const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = frame_map_.find(old_node);
- if (frame_it != frame_map_.end()) {
- for (auto node : new_nodes) {
- frame_map_.emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, optimized_graph_, node_map_.get());
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get());
+void ArithmeticOptimizer::ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
+ } else {
+ break;
}
}
}
@@ -1408,10 +1405,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map_->AddOutput(new_transpose->name(), new_cast->name());
nodes_to_simplify->PushBack(new_transpose);
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_transpose, new_cast},
- new_transpose->input(0), {new_transpose});
-
+ ForwardControlDependencies(new_transpose, {cast, node});
return new_cast->name();
}
}
@@ -1485,7 +1479,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
node_map_->AddOutput(weights->name(), scaled_weights->name());
scaled_weights->add_input(mul->input(1));
node_map_->AddOutput(scale->name(), scaled_weights->name());
- AddFrameControlDeps(node, {scaled_weights}, "", {});
+ ForwardControlDependencies(scaled_weights, {source});
// Update `conv`'s weights to `scaled_weights`.
conv->set_input(1, scaled_weights->name());
@@ -1521,7 +1515,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
- // Discard aggregate nodes with a single input.
+ // Discard aggregate nodes with a single input and no control dependencies.
if (node->input_size() == 1) {
return node->input(0);
}
@@ -1567,6 +1561,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return "";
}
new_const_node->set_device(node->device());
+ MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+ optimized_graph_, node_map_.get());
nodes_to_simplify->PushBack(new_const_node);
// 2. Replace the aggregate node with Mul(Const(N), x).
@@ -1579,9 +1575,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_mul_node->add_input(node->input(0));
node_map_->AddOutput(node->input(0), new_mul_node->name());
- CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get());
- AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0),
- {new_const_node});
+ ForwardControlDependencies(new_mul_node, {node});
return new_mul_node->name();
}
}
@@ -1614,7 +1608,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
FlipBooleanAttr(attr_a, new_op);
new_op->set_input(0, a->input(0));
node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
- AddFrameControlDeps(node, {new_op}, a->input(0), {new_op});
}
if (b_is_foldable) {
const string attr_b =
@@ -1622,10 +1615,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
FlipBooleanAttr(attr_b, new_op);
new_op->set_input(1, b->input(0));
node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
- if (!a_is_foldable) {
- AddFrameControlDeps(node, {new_op}, b->input(0), {new_op});
- }
}
+ std::vector<const NodeDef*> deps_to_forward({node});
+ if (a_is_foldable) {
+ deps_to_forward.push_back(a);
+ }
+ if (b_is_foldable) {
+ deps_to_forward.push_back(b);
+ }
+ ForwardControlDependencies(new_op, deps_to_forward);
}
}
@@ -1647,7 +1645,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
: "Transpose");
new_op->set_input(0, input->input(0));
node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
- AddFrameControlDeps(node, {new_op}, "", {});
+ ForwardControlDependencies(new_op, {node, input});
return new_op->name();
}
}
@@ -1663,8 +1661,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get(),
- &frame_map_);
+ graph_properties_.get(), node_map_.get());
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -1764,11 +1761,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
graph_properties_.reset(new GraphProperties(item));
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
- // Identify loop frames
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
-
// Perform the optimizations.
TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 63a7b55893..7e81ed0a1f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -100,13 +99,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Dedup redundant nodes in the graph.
void DedupComputations();
- // Fix frame dependencies by adding control dependencies from old_input to
- // nodes in new_nodes_for_control_dep, and update frame_map for all nodes in
- // new_nodes.
- void AddFrameControlDeps(const NodeDef* old_node,
- const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep);
+ // Forward the control dependencies anchored on src_nodes to the target_nodes.
+ void ForwardControlDependencies(NodeDef* target_node,
+ const std::vector<const NodeDef*>& src_nodes);
// Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
// transposes.
@@ -135,7 +130,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fetch_nodes_known_ = false;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
- FrameMap frame_map_;
std::unique_ptr<GraphProperties> graph_properties_;
GraphDef* optimized_graph_ = nullptr; // Not owned.
};
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 48f1dd5aa1..e117341ba3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -520,26 +520,23 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
ASSERT_NE(add_6_node, nullptr);
- EXPECT_EQ(3, add_6_node->input_size());
+ EXPECT_EQ(2, add_6_node->input_size());
EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
- EXPECT_EQ("^Placeholder", add_6_node->input(2));
const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
ASSERT_NE(add_4_node, nullptr);
EXPECT_EQ("Add", add_4_node->op());
- EXPECT_EQ(3, add_4_node->input_size());
+ EXPECT_EQ(2, add_4_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
- EXPECT_EQ("^Placeholder", add_4_node->input(2));
const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
ASSERT_NE(add_5_node, nullptr);
EXPECT_EQ("Add", add_5_node->op());
- EXPECT_EQ(3, add_5_node->input_size());
+ EXPECT_EQ(2, add_5_node->input_size());
EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
- EXPECT_EQ("^Placeholder", add_5_node->input(2));
const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
ASSERT_NE(add_const_node, nullptr);
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 8d3e965c57..7ed0474861 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
namespace tensorflow {
namespace grappler {
@@ -45,21 +44,16 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
- GraphProperties* graph_properties, NodeMap* node_map,
- FrameMap* frame_map)
+ GraphProperties* graph_properties, NodeMap* node_map)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
- node_map(node_map),
- frame_map(frame_map) {}
+ node_map(node_map) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
- // TODO(ezhulenev): it seems that frame_map is only relevant for loop
- // optimizer? Move it to loop-optimizer specific context extension.
- FrameMap* frame_map;
};
Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 416327e622..3f5ab87a5a 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -58,8 +58,8 @@ TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScope_InScope) {
TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ nullptr,
- /*graph_properties*/ nullptr, /*node_name*/ nullptr,
- /*frame_map*/ nullptr);
+ /*graph_properties*/ nullptr,
+ /*node_name*/ nullptr);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,8 +94,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map,
- /*frame_map*/ nullptr);
+ /*node_name*/ &node_map);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -134,8 +133,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map,
- /*frame_map*/ nullptr);
+ /*node_name*/ &node_map);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -165,4 +163,4 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
} // namespace
} // end namespace grappler
-} // end namespace tensorflow
\ No newline at end of file
+} // end namespace tensorflow