aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/simple_placer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer.cc')
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index d3110cba04..f6e6bf0692 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -605,7 +605,7 @@ bool IsMetadataNode(const Node* node) {
// outputs that are connected to nodes in the same colocation group.
bool IsGeneratorNode(const Node* node) {
return node->num_inputs() == 0 && node->num_outputs() == 1 &&
- node->out_edges().size() == 1 && !IsRefType(node->output_type(0));
+ !IsRefType(node->output_type(0));
}
} // namespace
@@ -730,9 +730,9 @@ Status SimplePlacer::Run() {
// Heuristic A: prefer to place "generators" with their only
// consumers.
//
- // If this is a node with no inputs and a single (non-ref)
- // consumer, we save this for a second pass, so that the
- // consumer's placement is chosen.
+ // If this is a node with no inputs and one output, we save
+ // this for a second pass, so that the consumer's placement
+ // is chosen.
if (IsGeneratorNode(node)) {
second_pass.push_back(node);
continue;
@@ -794,7 +794,15 @@ Status SimplePlacer::Run() {
if (IsGeneratorNode(node)) {
const Node* output = (*node->out_edges().begin())->dst();
const string& output_device_name = output->assigned_device_name();
- if (CanAssignToDevice(output_device_name, devices)) {
+
+ const bool consumers_on_same_device = std::all_of(
+ node->out_edges().begin(), node->out_edges().end(),
+ [output_device_name](const Edge* e) {
+ return e->dst()->assigned_device_name() == output_device_name;
+ });
+
+ if (consumers_on_same_device &&
+ CanAssignToDevice(output_device_name, devices)) {
assigned_device = output_device_name;
}
}