diff options
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer.cc')
-rw-r--r-- | tensorflow/core/common_runtime/simple_placer.cc | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index d3110cba04..f6e6bf0692 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -605,7 +605,7 @@ bool IsMetadataNode(const Node* node) { // outputs that are connected to nodes in the same colocation group. bool IsGeneratorNode(const Node* node) { return node->num_inputs() == 0 && node->num_outputs() == 1 && - node->out_edges().size() == 1 && !IsRefType(node->output_type(0)); + !IsRefType(node->output_type(0)); } } // namespace @@ -730,9 +730,9 @@ Status SimplePlacer::Run() { // Heuristic A: prefer to place "generators" with their only // consumers. // - // If this is a node with no inputs and a single (non-ref) - // consumer, we save this for a second pass, so that the - // consumer's placement is chosen. + // If this is a node with no inputs and one output, we save + // this for a second pass, so that the consumer's placement + // is chosen. if (IsGeneratorNode(node)) { second_pass.push_back(node); continue; @@ -794,7 +794,15 @@ Status SimplePlacer::Run() { if (IsGeneratorNode(node)) { const Node* output = (*node->out_edges().begin())->dst(); const string& output_device_name = output->assigned_device_name(); - if (CanAssignToDevice(output_device_name, devices)) { + + const bool consumers_on_same_device = std::all_of( + node->out_edges().begin(), node->out_edges().end(), + [output_device_name](const Edge* e) { + return e->dst()->assigned_device_name() == output_device_name; + }); + + if (consumers_on_same_device && + CanAssignToDevice(output_device_name, devices)) { assigned_device = output_device_name; } } |