aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-12 09:12:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-12 09:16:44 -0700
commit6d3bb6cac26684a2553a7a9fa04dd5b12f5434f3 (patch)
treefb1b8e5a1a55d33ff3d8d9ec9bc31ebd059d13b1
parente1066ba1a4166ba5ff7ca02ae70e5c44fc385789 (diff)
Don't remove identity nodes if they follow a device crossing and have consumers on a different device than the identity node itself. They may be used to cache or route data between devices in a deliberate manner.
Simplify code in DependencyOptimizer a bit. PiperOrigin-RevId: 188730185
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer.cc58
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc51
2 files changed, 80 insertions, 29 deletions
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index a5b2572c9c..63bc19630d 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -274,12 +274,17 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
// +----------+ y --^> b
if (is_noop || is_identity) {
+ if (is_identity && !SafeToRemoveIdentity(*node)) {
+ return;
+ }
+
const auto& output_node_set = node_map_->GetOutputs(node_name);
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
output_node_set.end());
const int num_outputs = output_nodes.size();
const int num_inputs = node->input_size();
+ // Don't increase the number of edges in the graph.
if (num_inputs * num_outputs > num_inputs + num_outputs) {
return;
}
@@ -293,39 +298,34 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
input_nodes.push_back(input_node);
}
- // Make sure that we don't increase the number of edges that cross
- // device boundaries.
- if ((num_inputs == 1 && num_outputs > 1 &&
- input_nodes[0]->device() != node->device()) ||
- (num_inputs > 1 && num_outputs == 1 &&
- output_nodes[0]->device() != node->device())) {
+ // TODO(rmlarsen): Not all device crossings are equally expensive.
+ // Assign a cost to each based on device affinity and compute a
+ // cost before and after.
+ const string& node_dev = node->device();
+ int num_cross_in = 0;
+ for (NodeDef* input_node : input_nodes) {
+ num_cross_in += static_cast<int>(input_node->device() != node_dev);
+ }
+ int num_cross_out = 0;
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_out += static_cast<int>(output_node->device() != node_dev);
+ }
+ if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
+ // This identity node follows a device crossing, so it might be
+ // following a _Recv node after partitioning. Do not remove such nodes,
+ // unless they only have consumers on the same device as themselves.
return;
}
- if (num_inputs == 2 && num_outputs == 2) {
- const string& noop_dev = node->device();
- const string& in0_dev = input_nodes[0]->device();
- const string& in1_dev = input_nodes[1]->device();
- const string& out0_dev = output_nodes[0]->device();
- const string& out1_dev = output_nodes[1]->device();
- const int num_cross_before = static_cast<int>(in0_dev != noop_dev) +
- static_cast<int>(in1_dev != noop_dev) +
- static_cast<int>(out0_dev != noop_dev) +
- static_cast<int>(out1_dev != noop_dev);
- const int num_cross_after = static_cast<int>(in0_dev != out0_dev) +
- static_cast<int>(in0_dev != out1_dev) +
- static_cast<int>(in1_dev != out0_dev) +
- static_cast<int>(in1_dev != out1_dev);
- if (num_cross_after > num_cross_before) {
- return;
- }
- // To avoid potentially removing Identity nodes following _Recv nodes,
- // we require that no device crossings occur in that case.
- // TODO(rmlarsen): See if we can relax this condition.
- if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) {
- return;
+ const int num_cross_before = num_cross_in + num_cross_out;
+ int num_cross_after = 0;
+ for (NodeDef* input_node : input_nodes) {
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_after +=
+ static_cast<int>(input_node->device() != output_node->device());
}
}
- if (is_identity && !SafeToRemoveIdentity(*node)) {
+ if (num_cross_after > num_cross_before) {
+ // Avoid increasing the number of device crossings.
return;
}
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index b66cc17a72..cc1e142041 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -595,6 +595,57 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
EXPECT_EQ("id_b:1", output.node(8).input(0));
}
+TEST_F(DependencyOptimizerTest,
+ Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x_on_1 =
+ ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
+ Output one_on_3 =
+ ops::Const(s.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
+ Output x_on_2 =
+ ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
+ Output result =
+ ops::Add(s.WithOpName("result").WithDevice("/gpu:3"), x_on_2, one_on_3);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"result"};
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ VerifyGraphsEqual(item.graph, output, __FUNCTION__);
+}
+
+TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x_on_1 =
+ ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
+ Output one_on_2 =
+ ops::Const(s.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
+ Output x_on_2 =
+ ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
+ Output result =
+ ops::Add(s.WithOpName("result").WithDevice("/gpu:2"), x_on_2, one_on_2);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"result"};
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(3, output.node_size());
+ for (const auto& node : output.node()) {
+ EXPECT_NE("x_on_2", node.name());
+ if (node.name() == "result") {
+ EXPECT_EQ("x_on_1", node.input(0));
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow