diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-12 09:12:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-12 09:16:44 -0700 |
commit | 6d3bb6cac26684a2553a7a9fa04dd5b12f5434f3 (patch) | |
tree | fb1b8e5a1a55d33ff3d8d9ec9bc31ebd059d13b1 | |
parent | e1066ba1a4166ba5ff7ca02ae70e5c44fc385789 (diff) |
Don't remove identity nodes if they follow a device crossing and have consumers on a device different than themselves. They may be used to cache or route data between devices in a deliberate manner.
Simplify code in DependencyOptimizer a bit.
PiperOrigin-RevId: 188730185
-rw-r--r-- | tensorflow/core/grappler/optimizers/dependency_optimizer.cc | 58 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc | 51 |
2 files changed, 80 insertions, 29 deletions
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index a5b2572c9c..63bc19630d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -274,12 +274,17 @@ void DependencyOptimizer::OptimizeNode(int node_idx, // +----------+ y --^> b if (is_noop || is_identity) { + if (is_identity && !SafeToRemoveIdentity(*node)) { + return; + } + const auto& output_node_set = node_map_->GetOutputs(node_name); const std::vector<NodeDef*> output_nodes(output_node_set.begin(), output_node_set.end()); const int num_outputs = output_nodes.size(); const int num_inputs = node->input_size(); + // Don't increase the number of edges in the graph. if (num_inputs * num_outputs > num_inputs + num_outputs) { return; } @@ -293,39 +298,34 @@ void DependencyOptimizer::OptimizeNode(int node_idx, input_nodes.push_back(input_node); } - // Make sure that we don't increase the number of edges that cross - // device boundaries. - if ((num_inputs == 1 && num_outputs > 1 && - input_nodes[0]->device() != node->device()) || - (num_inputs > 1 && num_outputs == 1 && - output_nodes[0]->device() != node->device())) { + // TODO(rmlarsen): Not all device crossings are equally expensive. + // Assign a cost to each based on device affinity and compute a + // cost before and after. + const string& node_dev = node->device(); + int num_cross_in = 0; + for (NodeDef* input_node : input_nodes) { + num_cross_in += static_cast<int>(input_node->device() != node_dev); + } + int num_cross_out = 0; + for (NodeDef* output_node : output_nodes) { + num_cross_out += static_cast<int>(output_node->device() != node_dev); + } + if (is_identity && num_cross_in > 0 && num_cross_out > 0) { + // This identity node follows a device crossing, so it might be + // following a _Recv node after partioning. Do not remove such nodes, + // unless they only have consumers on the same device as themselves. return; } - if (num_inputs == 2 && num_outputs == 2) { - const string& noop_dev = node->device(); - const string& in0_dev = input_nodes[0]->device(); - const string& in1_dev = input_nodes[1]->device(); - const string& out0_dev = output_nodes[0]->device(); - const string& out1_dev = output_nodes[1]->device(); - const int num_cross_before = static_cast<int>(in0_dev != noop_dev) + - static_cast<int>(in1_dev != noop_dev) + - static_cast<int>(out0_dev != noop_dev) + - static_cast<int>(out1_dev != noop_dev); - const int num_cross_after = static_cast<int>(in0_dev != out0_dev) + - static_cast<int>(in0_dev != out1_dev) + - static_cast<int>(in1_dev != out0_dev) + - static_cast<int>(in1_dev != out1_dev); - if (num_cross_after > num_cross_before) { - return; - } - // To avoid potentially removing Identity nodes following _Recv nodes, - // we require that no device crossings occur in that case. - // TODO(rmlarsen): See if we can relax this condition. - if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) { - return; + const int num_cross_before = num_cross_in + num_cross_out; + int num_cross_after = 0; + for (NodeDef* input_node : input_nodes) { + for (NodeDef* output_node : output_nodes) { + num_cross_after += + static_cast<int>(input_node->device() != output_node->device()); } } - if (is_identity && !SafeToRemoveIdentity(*node)) { + if (num_cross_after > num_cross_before) { + // Avoid increasing the number of device crossings. return; } diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index b66cc17a72..cc1e142041 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -595,6 +595,57 @@ TEST_F(DependencyOptimizerTest, IdentityN) { EXPECT_EQ("id_b:1", output.node(8).input(0)); } +TEST_F(DependencyOptimizerTest, + Identity_DeviceCrossing_ConsumerOnDifferentDevice) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x_on_1 = + ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {}); + Output one_on_3 = + ops::Const(s.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {}); + Output x_on_2 = + ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1); + Output result = + ops::Add(s.WithOpName("result").WithDevice("/gpu:3"), x_on_2, one_on_3); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"result"}; + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x_on_1 = + ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {}); + Output one_on_2 = + ops::Const(s.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {}); + Output x_on_2 = + ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1); + Output result = + ops::Add(s.WithOpName("result").WithDevice("/gpu:2"), x_on_2, one_on_2); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"result"}; + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + LOG(INFO) << output.DebugString(); + EXPECT_EQ(3, output.node_size()); + for (const auto& node : output.node()) { + EXPECT_NE("x_on_2", node.name()); + if (node.name() == "result") { + EXPECT_EQ("x_on_1", node.input(0)); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow |