aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/placer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/placer.cc')
-rw-r--r--tensorflow/core/common_runtime/placer.cc880
1 files changed, 880 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
new file mode 100644
index 0000000000..73fdf60fd5
--- /dev/null
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -0,0 +1,880 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/placer.h"
+
+#include <memory>
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+namespace {
+
+// StringPiece views of the colocation attribute name and group prefix, built
+// once here so that per-node lookups do not repeat the strlen() conversion.
+const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
+const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
+
+// Selects from 'devices' every device whose type appears in
+// 'supported_device_types', and returns the selection ordered by device type
+// priority (higher priority first) and then by device name
+// (lexicographically).
+std::vector<Device*> FilterSupportedDevices(
+    const std::vector<Device*>& devices,
+    const DeviceTypeVector& supported_device_types) {
+  // Collect the matching devices, one supported type at a time.
+  std::vector<Device*> matches;
+  for (const DeviceType& wanted_type : supported_device_types) {
+    for (Device* candidate : devices) {
+      const DeviceType candidate_type(candidate->attributes().device_type());
+      if (candidate_type == wanted_type) {
+        matches.push_back(candidate);
+      }
+    }
+  }
+
+  const auto priority_of = [](const Device* device) {
+    return DeviceSet::DeviceTypeOrder(DeviceType(device->device_type()));
+  };
+  std::sort(matches.begin(), matches.end(),
+            [&priority_of](const Device* lhs, const Device* rhs) {
+              const auto lhs_priority = priority_of(lhs);
+              const auto rhs_priority = priority_of(rhs);
+              // Equal priorities are ordered by name; otherwise the higher
+              // priority device type comes first.
+              if (lhs_priority == rhs_priority) {
+                return StringPiece(lhs->name()) < StringPiece(rhs->name());
+              }
+              return lhs_priority > rhs_priority;
+            });
+  return matches;
+}
+
+// This class maintains the connected components of a colocation
+// constraint graph, and uses this information to assign a satisfying
+// device placement to the nodes of the graph.
+//
+// The typical usage pattern is:
+//
+// Graph graph = ...;
+// DeviceSet device_set = ...;
+// ColocationGraph colocation_graph(&graph, &device_set,
+//                                  /*allow_soft_placement=*/true);
+//
+// // Initialize one member per op node of the graph.
+// TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers());
+//
+// // Apply the colocation groups declared on the nodes themselves.
+// TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes());
+//
+// // Add one or more further colocation constraints.
+// const Node& node_1 = *graph.FindNodeId(...);
+// const Node& node_2 = *graph.FindNodeId(...);
+// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
+//
+// // Query the feasible devices based on the accumulated constraints.
+// for (Node* node : graph.nodes()) {
+//   std::vector<Device*>* devices;
+//   TF_RETURN_IF_ERROR(colocation_graph.GetDevicesForNode(node, &devices));
+// }
+//
+// The implementation uses the union-find algorithm to maintain the
+// connected components efficiently and incrementally as edges
+// (implied by ColocationGraph::ColocateNodes() invocations) are added.
+class ColocationGraph {
+ public:
+ // Creates a colocation graph over 'graph' (not owned) for the devices in
+ // 'device_set' (not owned). One default-constructed Member slot is reserved
+ // for every node id of the graph.
+ ColocationGraph(Graph* graph, const DeviceSet* device_set,
+                 bool allow_soft_placement)
+     : graph_(graph),
+       members_(graph->num_node_ids()),
+       device_set_(device_set),
+       device_types_(device_set->PrioritizedDeviceTypeList()),
+       allow_soft_placement_(allow_soft_placement) {}
+
+ // Places every op node of the graph into its colocation group(s).
+ //
+ // A node names its groups through the entries of its colocation attribute
+ // that start with the colocation group prefix. A node without any such
+ // entry uses its own name as its colocation group.
+ //
+ // NOTE: The implementation assumes that the ids of the graph's nodes are
+ // dense and zero-based; the memory used will be linear in the largest
+ // node ID.
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status ColocateAllNodes() {
+   // Maps a colocation group name to the first node seen for that group.
+   // The keys are StringPieces that point into strings owned by the
+   // NodeDefs (attribute values with the group prefix stripped, or node
+   // names), so no string is allocated for them. This is safe because the
+   // map does not outlive this method and the NodeDefs are not modified
+   // while it is alive.
+   std::unordered_map<StringPiece, const Node*, StringPiece::Hasher>
+       group_to_root;
+
+   for (Node* node : graph_->nodes()) {
+     if (!node->IsOp()) continue;
+
+     // Walk the colocation attribute in place rather than copying it out
+     // through GetNodeAttr(): this avoids allocating a std::vector<string>
+     // and copying every string in it.
+     bool has_group_spec = false;
+     const AttrValue* colocation_attr =
+         node->attrs().Find(kColocationAttrNameStringPiece);
+     if (colocation_attr != nullptr && colocation_attr->has_list()) {
+       for (const string& raw_spec : colocation_attr->list().s()) {
+         StringPiece group_name(raw_spec);
+         // Entries without the colocation group prefix are ignored.
+         if (!group_name.Consume(kColocationGroupPrefixStringPiece)) {
+           continue;
+         }
+         has_group_spec = true;
+         TF_RETURN_IF_ERROR(
+             ColocateNodeToGroup(&group_to_root, node, group_name));
+       }
+     }
+
+     if (has_group_spec) continue;
+
+     // No explicit group: the node's own name serves as its group.
+     TF_RETURN_IF_ERROR(
+         ColocateNodeToGroup(&group_to_root, node, node->name()));
+   }
+
+   return Status::OK();
+ }
+
+ // Adds 'node' to the colocation group named 'colocation_group'. The first
+ // node seen for a group is recorded in '*colocation_group_root' as its
+ // root; every later node is colocated with that root. A colocation failure
+ // is returned with the definition of 'node' attached.
+ Status ColocateNodeToGroup(
+     std::unordered_map<StringPiece, const Node*, StringPiece::Hasher>*
+         colocation_group_root,
+     Node* node, StringPiece colocation_group) {
+   // operator[] inserts a null entry the first time a group is seen.
+   const Node*& group_root = (*colocation_group_root)[colocation_group];
+   if (group_root == nullptr) {
+     group_root = node;
+     return Status::OK();
+   }
+
+   const Status merge_status = ColocateNodes(*node, *group_root);
+   return merge_status.ok() ? Status::OK() : AttachDef(merge_status, *node);
+ }
+
+ // Unions the (possibly already identical) sets that contain nodes "x" and
+ // "y". Returns OK if all nodes in the resulting set can be placed on the
+ // same device type.
+ //
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status ColocateNodes(const Node& x, const Node& y) {
+   const int root_of_x = FindRoot(x.id());
+   const int root_of_y = FindRoot(y.id());
+   return ColocateNodes(x, root_of_x, y, root_of_y);
+ }
+
+ // This overload of ColocateNodes() allows a caller to provide the root node
+ // ids for the two nodes. For large graphs, this noticeably reduces the
+ // graph load time.
+ //
+ // 'x_root' and 'y_root' must be the current roots of 'x' and 'y'. Returns
+ // OK if the two sets were already the same or were merged successfully, and
+ // InvalidArgument if their device names cannot be merged or if no device
+ // type supports every node of the merged set.
+ //
+ // NOTE: The parent pointers are updated before the constraints are
+ // checked, so if this method returns an error, *this is left in an
+ // undefined state.
+ Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) {
+   if (x_root == y_root) {
+     return Status::OK();
+   }
+
+   DCHECK_EQ(x_root, FindRoot(x.id()));
+   DCHECK_EQ(y_root, FindRoot(y.id()));
+
+   Member& x_root_member = members_[x_root];
+   Member& y_root_member = members_[y_root];
+
+   // Merge the sets by swinging the parent pointer of the smaller
+   // tree to point to the root of the larger tree. Together with
+   // path compression in ColocationGraph::FindRoot, this ensures
+   // that we do not experience pathological performance on graphs
+   // such as chains.
+   int new_root, old_root;
+   if (x_root_member.rank < y_root_member.rank) {
+     // The tree rooted at x_root is shallower, so connect it to
+     // y_root. The rank of y_root is unchanged because its new
+     // child has strictly less rank.
+     x_root_member.parent = y_root;
+     new_root = y_root;
+     old_root = x_root;
+   } else if (x_root_member.rank > y_root_member.rank) {
+     // The tree rooted at y_root is shallower, so connect it to
+     // x_root. The rank of x_root is unchanged because its new
+     // child has strictly less rank.
+     y_root_member.parent = x_root;
+     new_root = x_root;
+     old_root = y_root;
+   } else {
+     // Both trees have the same rank, so break the tie by choosing
+     // x_root as the new root.
+     y_root_member.parent = x_root;
+     // Increment the rank of the tree rooted at x_root, because it
+     // is now strictly deeper than before.
+     ++x_root_member.rank;
+     new_root = x_root;
+     old_root = y_root;
+   }
+
+   Member& new_root_member = members_[new_root];
+   Member& old_root_member = members_[old_root];
+
+   // Merge the partial device specifications, and ensure that they are
+   // compatible. NULL options_ is treated as allowing soft placement.
+   // TODO(mrry): Consider enriching the error message by pointing
+   // out which nodes have the explicit partial device
+   // specifications that caused this conflict.
+   Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name,
+                                             old_root_member.device_name,
+                                             allow_soft_placement_);
+   if (!s.ok()) {
+     // Both node names are quoted; the quote after y's name is closed here.
+     return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
+                                    "' and '", y.name(), "': ",
+                                    s.error_message());
+   }
+
+   // Ensure that the common root has at least one supported device
+   // type, by computing the intersection of
+   // new_root_member.supported_device_types and
+   // old_root_member.supported_device_types.
+   MergeSupportedDevices(&new_root_member.supported_device_types,
+                         old_root_member.supported_device_types);
+   if (new_root_member.supported_device_types.empty()) {
+     return errors::InvalidArgument(
+         "Cannot colocate nodes '", x.name(), "' and '", y.name(),
+         "' because no device type supports both of those nodes and the "
+         "other nodes colocated with them.",
+         DebugInfo(x_root), DebugInfo(y_root));
+   }
+
+   return Status::OK();
+ }
+
+ // For the given node, subject to the constraints previously given
+ // to this ColocationGraph, computes the devices on which the node's
+ // colocation set may be placed. On success, '*possible_devices' points to
+ // a non-empty list sorted by preferred device type and then by name; on
+ // failure it is set to nullptr and an error is returned. This method does
+ // not modify 'node'.
+ //
+ // The result is cached in the root member of the node's colocation set, so
+ // it is computed at most once per set.
+ //
+ // Note: This method returns a pointer to a field within members_.
+ // The caller must not use the returned pointer after there is any possibility
+ // that the members_[i].possible_devices field has been modified.
+ Status GetDevicesForNode(Node* node,
+                          std::vector<Device*>** possible_devices) {
+   *possible_devices = nullptr;
+   const int node_root = FindRoot(node->id());
+   if (!members_[node_root].possible_devices.empty()) {
+     *possible_devices = &members_[node_root].possible_devices;
+     return Status::OK();
+   }
+
+   // We have not yet computed the possible devices for the
+   // colocated node set containing 'node', so we do so now using the
+   // constraints on the root node.
+
+   // "devices" will contain the set of feasible placements for the
+   // colocated node set containing 'node'.
+   std::vector<Device*> devices;
+   if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
+     // The root node has a (possibly partial) device
+     // specification, so enumerate the physical devices that
+     // conform to it.
+     device_set_->FindMatchingDevices(members_[node_root].device_name,
+                                      &devices);
+
+     if (!devices.empty()) {
+       // Filter devices into those that are compatible with the root
+       // node (and its children).
+       devices = FilterSupportedDevices(
+           devices, members_[node_root].supported_device_types);
+     }
+
+     // Perform soft placement if allow_soft_placement_ is set.
+     if (devices.empty() && allow_soft_placement_) {
+       // The soft_device_name is the same as the node's device name
+       // without specifying the device type or ID.
+       DeviceNameUtils::ParsedName soft_device_name =
+           members_[node_root].device_name;
+       soft_device_name.type.clear();
+       soft_device_name.has_type = false;
+       soft_device_name.has_id = false;
+       device_set_->FindMatchingDevices(soft_device_name, &devices);
+       if (!devices.empty()) {
+         devices = FilterSupportedDevices(
+             devices, members_[node_root].supported_device_types);
+       }
+     }
+
+     if (devices.empty()) {
+       // Return an error when a physical device that matches an explicit
+       // device specification is not found. This ensures that we don't
+       // assign a node to GPU when the user wanted to force it on CPU.
+       string debug_info = DebugInfo(node_root);
+
+       DeviceNameUtils::ParsedName specified_device_name;
+       if (DeviceNameUtils::ParseFullName(node->requested_device(),
+                                          &specified_device_name) &&
+           specified_device_name == members_[node_root].device_name) {
+         // The specified device and merged set device match, and
+         // will appear in the GraphDef (for debugging), so just
+         // print the specified device.
+         std::vector<Device*> devices_matching_nodedef;
+         device_set_->FindMatchingDevices(specified_device_name,
+                                          &devices_matching_nodedef);
+         if (devices_matching_nodedef.empty()) {
+           // Sometimes it is almost impossible to understand the problem
+           // without a list of available devices.
+           std::vector<string> device_names;
+           for (const Device* device : device_set_->devices()) {
+             device_names.push_back(device->name());
+           }
+           std::sort(device_names.begin(), device_names.end());
+
+           return errors::InvalidArgument(
+               "Operation was explicitly assigned to ",
+               node->requested_device(), " but available devices are [ ",
+               str_util::Join(device_names, ", "), " ]. Make sure ",
+               "the device specification refers to a valid device.");
+         } else if (specified_device_name.has_type) {
+           return errors::InvalidArgument(
+               "Could not satisfy explicit device specification '",
+               node->requested_device(), "' because no supported kernel for ",
+               specified_device_name.type, " devices is available.",
+               debug_info);
+         } else {
+           // Close the quote opened before the device specification so the
+           // message stays well-formed when debug_info is appended.
+           return errors::InvalidArgument(
+               "Could not satisfy explicit device specification '",
+               node->requested_device(), "'", debug_info);
+         }
+       } else {
+         // The specified device may be a valid device but the
+         // merged set device is different, so print both.
+         return errors::InvalidArgument(
+             "Could not satisfy explicit device specification '",
+             node->requested_device(),
+             "' because the node was colocated with a group of nodes that "
+             "required incompatible device '",
+             DeviceNameUtils::ParsedNameToString(
+                 members_[node_root].device_name),
+             "'", debug_info);
+       }
+     }
+   } else {
+     // The device is completely unspecified, so enumerate the devices that
+     // support all of the nodes in the set.
+     if (device_set_->devices().empty()) {
+       return errors::Internal("No devices are registered");
+     }
+     devices = FilterSupportedDevices(
+         device_set_->devices(), members_[node_root].supported_device_types);
+
+     if (devices.empty()) {
+       return errors::InvalidArgument(
+           "Node had no OpKernel registered to support this operation: ",
+           "Operation was ", node->type_string(), " and inputs were ",
+           DataTypeVectorString(node->input_types()), DebugInfo(node_root));
+     }
+   }
+
+   // Cache the result of the possible devices for this node group.
+   members_[node_root].possible_devices = std::move(devices);
+   *possible_devices = &members_[node_root].possible_devices;
+   return Status::OK();
+ }
+
+ // Initializes the Member entry of every op node in the graph. Stops at the
+ // first failure and returns it with the offending node's definition
+ // attached.
+ Status InitializeMembers() {
+   for (Node* node : graph_->nodes()) {
+     if (node->IsOp()) {
+       const Status init_status =
+           InitializeMember(*node, &members_[node->id()]);
+       if (!init_status.ok()) {
+         return AttachDef(init_status, *node);
+       }
+     }
+   }
+   return Status::OK();
+ }
+
+ // Represents a node in the disjoint node set forest, and the
+ // accumulated constraints on the device used by that node.
+ struct Member {
+ Member() = default;
+ // The id of the node that is the parent of this one, or its own
+ // id if it is a root. parent < 0 indicates that this member is invalid.
+ int parent = -1;
+
+ // A proxy for the depth of the tree that is used to prefer
+ // connecting smaller trees to larger trees when merging disjoint
+ // sets.
+ int rank = 0;
+
+ // The intersection of all device types supported by this node,
+ // and those of all of its children, in priority order
+ // of the preferred device.
+ DeviceTypeVector supported_device_types;
+
+ // The merged form of the device requested for this node, with
+ // those of all of its children.
+ DeviceNameUtils::ParsedName device_name;
+
+ // If this node is a root, caches the list of Devices on which this node
+ // and all of its children may be placed; empty if this has not yet
+ // been computed.
+ std::vector<Device*> possible_devices;
+ };
+
+ // Returns a human-readable summary of the colocation set rooted at
+ // 'node_root': for each op type in the set, the device types supported by a
+ // node of that type. Returns an empty string when at most one op node
+ // belongs to the set, since there is then no grouping to explain.
+ string DebugInfo(const int node_root) {
+   // Op type -> space-separated device types. When several nodes share an op
+   // type, the entry written last wins.
+   std::unordered_map<string, string> devices_by_op_type;
+   int group_size = 0;
+
+   for (const Node* node : graph_->nodes()) {
+     if (!node->IsOp()) continue;
+     const int node_id = node->id();
+     if (FindRoot(node_id) != node_root) continue;
+
+     ++group_size;
+     string supported;
+     for (const auto& device_type : members_[node_id].supported_device_types) {
+       strings::StrAppend(&supported, DeviceTypeString(device_type), " ");
+     }
+     devices_by_op_type[node->type_string()] = std::move(supported);
+   }
+
+   if (group_size <= 1) {
+     return string();
+   }
+
+   string text(
+       "\nColocation Debug Info:\n"
+       "Colocation group had the following types and devices: ");
+   for (const auto& entry : devices_by_op_type) {
+     strings::StrAppend(&text, "\n", entry.first, ": ", entry.second);
+   }
+   return text;
+ }
+
+ // Initializes '*member' as the singleton set for 'node': the member becomes
+ // its own parent, its supported device types are looked up from the node's
+ // definition, and its device name is parsed from the node's assigned device
+ // (if it has one) or otherwise from its requested device (if non-empty).
+ //
+ // Returns INTERNAL for a bad pre-assigned device, and InvalidArgument when
+ // no kernel supports the op or the requested device is malformed.
+ Status InitializeMember(const Node& node, Member* member) {
+   const int node_id = node.id();
+   DCHECK_GE(node_id, 0);
+   member->parent = node_id;
+   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+       device_types_, node.def(), &member->supported_device_types));
+
+   if (!node.has_assigned_device_name()) {
+     // This node has not yet been assigned to a device, so its constraints
+     // come from the set of registered kernels and from any (partial)
+     // user-provided device specification in the NodeDef.
+
+     // Without any registered kernel for this op the node cannot be placed.
+     if (member->supported_device_types.empty()) {
+       std::set<string> registered_device_types;
+       for (Device* d : device_set_->devices()) {
+         registered_device_types.insert(d->device_type());
+       }
+       return errors::InvalidArgument(
+           "No OpKernel was registered to support Op '", node.type_string(),
+           "' with these attrs. Registered devices: [",
+           str_util::Join(registered_device_types, ","),
+           "], Registered kernels:\n",
+           KernelsRegisteredForOp(node.type_string()));
+     }
+
+     // A non-empty device in the NodeDef is interpreted as a (partial)
+     // device specification.
+     // NOTE: The specification may name a device type that is not among the
+     // member's supported device types; that is not checked here.
+     if (!node.requested_device().empty() &&
+         !DeviceNameUtils::ParseFullName(node.requested_device(),
+                                         &member->device_name)) {
+       return errors::InvalidArgument("Malformed device specification '",
+                                      node.requested_device(), "'");
+     }
+     return Status::OK();
+   }
+
+   // This node has already been assigned to a device, so that placement is
+   // respected after sanity-checking it. The member's device name then
+   // reflects the assigned device, so any nodes colocated with this node
+   // will be constrained to it.
+   // NOTE: Since any assignment must have been performed by the TensorFlow
+   // runtime, errors in this part are considered INTERNAL.
+   const string& assigned_name = node.assigned_device_name();
+   if (!DeviceNameUtils::ParseFullName(assigned_name, &member->device_name)) {
+     return errors::Internal("Malformed assigned device '", assigned_name,
+                             "'");
+   }
+
+   const Device* assigned_device =
+       device_set_->FindDeviceByName(assigned_name);
+   if (assigned_device == nullptr) {
+     return errors::Internal("Assigned device '", assigned_name,
+                             "' does not match any device");
+   }
+
+   // The assigned device's type must be one the op has a kernel for.
+   const DeviceType assigned_type(assigned_device->attributes().device_type());
+   for (const DeviceType& supported_type : member->supported_device_types) {
+     if (assigned_type == supported_type) {
+       return Status::OK();
+     }
+   }
+
+   return errors::Internal("Assigned device '", assigned_name,
+                           "' does not have registered OpKernel support "
+                           "for ",
+                           node.type_string());
+ }
+
+ // Replaces the contents of "target" with the device types that appear in
+ // both "target" and "other", keeping the relative (priority) order they
+ // had in "target".
+ static void MergeSupportedDevices(DeviceTypeVector* target,
+                                   const DeviceTypeVector& other) {
+   const DeviceTypeVector previous = *target;
+   target->clear();
+
+   for (const DeviceType& candidate : previous) {
+     if (std::find(other.begin(), other.end(), candidate) != other.end()) {
+       target->push_back(candidate);
+     }
+   }
+ }
+
+ // Returns the id of the root of the disjoint-set tree that contains the
+ // node with the given id. As a side effect, every member on the path from
+ // 'node_id' to the root is re-parented directly onto the root, so that
+ // future calls to FindRoot and ColocateNodes are more efficient.
+ int FindRoot(int node_id) {
+   // First pass: follow the parent links until a member that is its own
+   // parent (the root) is reached.
+   int root = node_id;
+   while (true) {
+     const int next = members_[root].parent;
+     DCHECK_GE(next, 0);
+     if (next == root) {
+       break;
+     }
+     root = next;
+   }
+
+   // Second pass: point every member on the walked path at the root.
+   int current = node_id;
+   while (current != root) {
+     Member& member = members_[current];
+     const int next = member.parent;
+     member.parent = root;
+     current = next;
+   }
+
+   DCHECK_GE(root, 0);
+   return root;
+ }
+
+ Graph* const graph_; // Not owned.
+ std::vector<Member> members_;
+ const DeviceSet* device_set_; // Not owned.
+ const std::vector<DeviceType> device_types_;
+ const bool allow_soft_placement_;
+};
+
+// Returns true if 'node' is a "generator" candidate: it has no inputs and
+// exactly one output, and that output is not of a reference type. The number
+// of consumers of that output is not examined here.
+//
+// TODO(vrv): Currently this handles only nodes with one output, but
+// this could be extended to handle the case where a node has many
+// outputs that are connected to nodes in the same colocation group.
+bool IsGeneratorNode(const Node* node) {
+  if (node->num_inputs() != 0 || node->num_outputs() != 1) {
+    return false;
+  }
+  return !IsRefType(node->output_type(0));
+}
+
+} // namespace
+
+// Creates a placer for 'graph' over 'devices'. 'options' may be null, in
+// which case device placement logging is disabled.
+Placer::Placer(Graph* graph, const DeviceSet* devices,
+               const SessionOptions* options)
+    : graph_(graph),
+      devices_(devices),
+      options_(options),
+      log_device_placement_(
+          options == nullptr ? false
+                             : options->config.log_device_placement()) {}
+
+Placer::Placer(Graph* graph, const DeviceSet* devices)
+ : Placer(graph, devices, /*options=*/nullptr) {}
+
+Placer::~Placer() = default;
+
+// Assigns a device to every op node of the graph that does not already have
+// one. Colocation groups and reference/resource edges are honoured first;
+// then each node receives the first device of its sorted feasible list,
+// unless one of the two heuristics below selects the device of a neighbour.
+//
+// Returns FailedPrecondition if no devices are registered, and
+// InvalidArgument (with the offending node attached) if the constraints
+// cannot be satisfied.
+Status Placer::Run() {
+  if (devices_->devices().empty()) {
+    return errors::FailedPrecondition("No devices are registered");
+  }
+
+  ColocationGraph colocation_graph(
+      graph_, devices_,
+      options_ == nullptr || options_->config.allow_soft_placement());
+
+  TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers());
+
+  // 1. First add all of the nodes. Note that steps (1) and (2)
+  // requires two passes over the nodes because the graph (and hence
+  // the constraints) may not be acyclic.
+  TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes());
+
+  // 2. Enumerate the constraint edges, and use them to update the disjoint
+  // node set.
+
+  // If `node` has an input edge with reference type, add an
+  // edge from the source of that edge to `node`.
+  for (const Edge* edge : graph_->edges()) {
+    if (edge->IsControlEdge()) {
+      continue;
+    }
+    Node* src = edge->src();
+    Node* dst = edge->dst();
+    DataType input_type = dst->input_type(edge->dst_input());
+    if (input_type == DT_RESOURCE || IsRefType(input_type)) {
+      int src_root_id = colocation_graph.FindRoot(src->id());
+      int dst_root_id = colocation_graph.FindRoot(dst->id());
+      auto& src_root = colocation_graph.members_[src_root_id];
+      auto& dst_root = colocation_graph.members_[dst_root_id];
+      // If both the source node and this node have partially
+      // specified a device, then 'node's device should be
+      // cleared: the reference edge forces 'node' to be on the
+      // same device as the source node.
+      const auto& source_parsed_name = src_root.device_name;
+      const auto& dest_parsed_name = dst_root.device_name;
+      if (DeviceNameUtils::HasSomeDetails(source_parsed_name) &&
+          DeviceNameUtils::HasSomeDetails(dest_parsed_name)) {
+        // Ignore a specified device for 'dst' if the two names were
+        // incompatible.
+        if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
+                                                    dest_parsed_name)) {
+          if (log_device_placement_) {
+            LOG(INFO) << "Ignoring device specification "
+                      << DeviceNameUtils::ParsedNameToString(dest_parsed_name)
+                      << " for node '" << dst->name()
+                      << "' because the input edge from '" << src->name()
+                      << "' is a reference connection and already has a device "
+                         "field set to "
+                      << DeviceNameUtils::ParsedNameToString(
+                             source_parsed_name);
+          }
+
+          // Make 'dst' colocated with the source
+          dst_root.device_name = source_parsed_name;
+        } else {
+          bool source_subset_of_dest = DeviceNameUtils::IsSpecification(
+              source_parsed_name, dest_parsed_name);
+          bool dest_subset_of_source = DeviceNameUtils::IsSpecification(
+              dest_parsed_name, source_parsed_name);
+
+          if (source_subset_of_dest && !dest_subset_of_source) {
+            src_root.device_name = dest_parsed_name;
+          } else {
+            dst_root.device_name = source_parsed_name;
+          }
+        }
+      }
+
+      Status status =
+          colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id);
+      if (!status.ok()) {
+        return AttachDef(
+            errors::InvalidArgument("Nodes were connected by a "
+                                    "reference connection (requiring them to "
+                                    "be on the same device), but the two nodes "
+                                    "were assigned two different devices: ",
+                                    status.error_message()),
+            *dst);
+      }
+    }
+  }
+
+  // 3. For each node, assign a device based on the constraints in the
+  // disjoint node set.
+  std::vector<Node*> second_pass;
+  for (Node* node : graph_->op_nodes()) {
+    // The graph may have come pre-populated by the framework with assigned
+    // devices (e.g., for stateful placements), so the placer should not try to
+    // place nodes that are already placed.
+    if (node->has_assigned_device_name()) {
+      LogDeviceAssignment(node);
+      continue;
+    }
+
+    // Heuristic A: prefer to place "generators" with their only
+    // consumers.
+    //
+    // If this is a node with no inputs and one output, we save
+    // this for a second pass, so that the consumer's placement
+    // is chosen.
+    if (IsGeneratorNode(node)) {
+      second_pass.push_back(node);
+      continue;
+    }
+
+    std::vector<Device*>* devices;
+    Status status = colocation_graph.GetDevicesForNode(node, &devices);
+    if (!status.ok()) {
+      return AttachDef(
+          errors::InvalidArgument("Cannot assign a device for operation '",
+                                  node->name(), "': ", status.error_message()),
+          *node);
+    }
+
+    // Returns the first device in sorted devices list so we will always
+    // choose the same device.
+    //
+    // TODO(vrv): Factor this assignment out into a pluggable
+    // algorithm, so that Placer is responsible for enforcing
+    // preconditions and we can experiment with other algorithms when
+    // given a choice of devices. Once we have a better idea of the
+    // types of heuristics we want to use and the information needed
+    // to perform good placement we can add an interface for this.
+    int assigned_device = -1;
+
+    // Heuristic B: If the node only operates on metadata, not data,
+    // then it is desirable to place that metadata node with its
+    // input.
+    //
+    // The in-edge set is checked first: dereferencing begin() of an empty
+    // edge set is undefined, and a node without in-edges simply falls
+    // through to the default below.
+    if (IsMetadata(node) && !node->in_edges().empty()) {
+      // Make sure that the input device type is in the list of supported
+      // device types for this node.
+      const Node* input = (*node->in_edges().begin())->src();
+      // TODO(vrv): if the input is empty, consider postponing this
+      // node's assignment to the second pass, so that we handle the
+      // case where a metadata node's input comes from a backedge
+      // of a loop.
+      if (CanAssignToDevice(input->assigned_device_name(), *devices)) {
+        assigned_device = input->assigned_device_name_index();
+      }
+    }
+
+    // Provide the default, if necessary.
+    if (assigned_device == -1) {
+      assigned_device = graph_->InternDeviceName((*devices)[0]->name());
+    }
+
+    AssignAndLog(assigned_device, node);
+  }
+
+  // 4. Perform a second pass assignment for those nodes explicitly
+  // skipped during the first pass.
+  for (Node* node : second_pass) {
+    std::vector<Device*>* devices;
+    Status status = colocation_graph.GetDevicesForNode(node, &devices);
+    if (!status.ok()) {
+      return AttachDef(
+          errors::InvalidArgument("Cannot assign a device for operation '",
+                                  node->name(), "': ", status.error_message()),
+          *node);
+    }
+
+    int assigned_device = -1;
+
+    // Heuristic A application.
+    //
+    // A generator whose output has no out-edge at all has no consumer to
+    // follow, so it is skipped here (dereferencing begin() of an empty edge
+    // set is undefined) and receives the default device below.
+    if (IsGeneratorNode(node) && !node->out_edges().empty()) {
+      const Node* output = (*node->out_edges().begin())->dst();
+      int output_device_name = output->assigned_device_name_index();
+
+      const bool consumers_on_same_device = std::all_of(
+          node->out_edges().begin(), node->out_edges().end(),
+          [output_device_name](const Edge* e) {
+            return e->dst()->assigned_device_name_index() == output_device_name;
+          });
+
+      if (consumers_on_same_device &&
+          CanAssignToDevice(output->assigned_device_name(), *devices)) {
+        assigned_device = output_device_name;
+      }
+    }
+
+    // Provide the default, if necessary.
+    if (assigned_device == -1) {
+      assigned_device = graph_->InternDeviceName((*devices)[0]->name());
+    }
+
+    AssignAndLog(assigned_device, node);
+  }
+
+  return Status::OK();
+}
+
+// Returns true if "candidate_device_name" is non-empty and the device it
+// resolves to in the placer's device set is one of 'devices', the devices
+// that the placer or the user has constrained the operation to.
+bool Placer::CanAssignToDevice(const string& candidate_device_name,
+                               const std::vector<Device*>& devices) const {
+  if (candidate_device_name.empty()) {
+    return false;
+  }
+
+  const Device* candidate = devices_->FindDeviceByName(candidate_device_name);
+  return std::find(devices.begin(), devices.end(), candidate) != devices.end();
+}
+
+// Records 'assigned_device' (a device-name index) as the assigned device of
+// 'node', then reports the assignment via LogDeviceAssignment().
+void Placer::AssignAndLog(int assigned_device, Node* node) const {
+ node->set_assigned_device_name_index(assigned_device);
+ LogDeviceAssignment(node);
+}
+
+// Reports the device assigned to 'node' on stdout and in the INFO log when
+// log_device_placement is set; otherwise does nothing. Both outputs use the
+// same "<name>: (<op type>): <device>" layout.
+void Placer::LogDeviceAssignment(const Node* node) const {
+  if (!log_device_placement_) {
+    return;
+  }
+  printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(),
+         node->assigned_device_name().c_str());
+  // Emit the ": " separator before the device name so that the log line
+  // matches the printf line above.
+  LOG(INFO) << node->name() << ": "
+            << "(" << node->type_string() << "): "
+            << node->assigned_device_name();
+}
+
+} // namespace tensorflow