aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/simple_placer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer.cc')
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc559
1 files changed, 559 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
new file mode 100644
index 0000000000..1cd1db29db
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -0,0 +1,559 @@
+#include "tensorflow/core/common_runtime/simple_placer.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Returns a list of devices sorted by name from 'devices' whose type is in
+// 'supported_device_types'. This function searches in order of the device
+// types in 'supported_device_types' and returns the *first* subset of devices
+// that match.
+//
+// For example, if suported_device_types contains {GPU, CPU} and
+// 'devices' contains CPU and GPU devices, the returned vector will
+// include *only* GPU devices, since that is higher in the priority
+// order in 'supported_device_types'.
+std::vector<Device*> FilterSupportedDevices(
+ const std::vector<Device*>& devices,
+ const DeviceTypeVector& supported_device_types) {
+ std::vector<Device*> filtered_devices;
+ auto device_sort = [](const Device* a, const Device* b) {
+ return a->name() < b->name();
+ };
+ for (DeviceType d : supported_device_types) {
+ for (Device* device : devices) {
+ if (DeviceType(device->attributes().device_type()) == d) {
+ filtered_devices.emplace_back(device);
+ }
+ }
+
+ // If there are any devices under this device type, return this
+ // subset.
+ if (!filtered_devices.empty()) {
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+ }
+ }
+
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+}
+
+bool HasColocatedNodeName(const Node& node) {
+ return StringPiece(node.def().device()).starts_with("@");
+}
+
+Status ParseColocatedNodeName(const Node& node,
+ string* out_colocated_node_name) {
+ StringPiece device(node.def().device());
+ if (!device.Consume("@")) {
+ return errors::InvalidArgument("Malformed colocated node name: '", device,
+ "'");
+ }
+ // TODO(mrry): Validate that the node name is a valid node name.
+ *out_colocated_node_name = device.ToString();
+ return Status::OK();
+}
+
+// This class maintains the connected components of a colocation
+// constraint graph, and uses this information to assign a satisfying
+// device placement to the nodes of the graph.
+//
+// The typical usage pattern is:
+//
+// Graph graph = ...;
+// DeviceSet device_set = ...;
+// ColocationGraph colocation_graph(graph, device_set);
+//
+// // Add all the nodes of graph to colocation_graph.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
+// }
+//
+// // Add one or more colocation constraint.
+// Node node_1 = *graph.FindNodeId(...);
+// Node node_2 = *graph.FindNodeId(...);
+// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
+//
+// // Assign devices based on the accumulated constraints.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
+// }
+//
+// The implementation uses the union-find algorithm to maintain the
+// connected components efficiently and incrementally as edges
+// (implied by ColocationGraph::ColocateNodes() invocations) are added.
+class ColocationGraph {
+ public:
+ ColocationGraph(Graph* graph, const DeviceSet* device_set,
+ const SessionOptions* options)
+ : device_set_(device_set),
+ device_types_(device_set->PrioritizedDeviceTypeList()),
+ options_(options) {
+ members_.reserve(graph->num_node_ids());
+ }
+
+ // Adds the given node to this ColocationGraph as a singleton.
+ //
+ // NOTE: The implementation assumes that the ids of nodes passed to
+ // this method are dense and zero-based; the memory used will be linear in
+ // the largest node ID.
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status AddNode(const Node& node) {
+ Member member;
+ TF_RETURN_IF_ERROR(InitializeMember(node, &member));
+ CHECK_GE(member.parent, 0);
+ members_.resize(member.parent + 1);
+ members_[member.parent] = std::move(member);
+ return Status::OK();
+ }
+
+ // Merge the (possibly disjoint) sets containing nodes "x" and
+ // "y". Returns OK if the all nodes in the union of these sets can
+ // be placed on the same device type.
+ //
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status ColocateNodes(const Node& x, const Node& y) {
+ int x_root = FindRoot(x.id());
+ int y_root = FindRoot(y.id());
+ if (x_root != y_root) {
+ // Merge the sets by swinging the parent pointer of the smaller
+ // tree to point to the root of the larger tree. Together with
+ // path compression in ColocationGraph::FindRoot, this ensures
+ // that we do not experience pathological performance on graphs
+ // such as chains.
+ int new_root, old_root;
+ if (members_[x_root].rank < members_[y_root].rank) {
+ // The tree rooted at x_root is shallower, so connect it to
+ // y_root. The rank of y_root is unchanged because its new
+ // child has strictly less rank.
+ members_[x_root].parent = y_root;
+ new_root = y_root;
+ old_root = x_root;
+ } else if (members_[x_root].rank > members_[y_root].rank) {
+ // The tree rooted at y_root is shallower, so connect it to
+ // x_root. The rank of x_root is unchanged because its new
+ // child has strictly less rank.
+ members_[y_root].parent = x_root;
+ new_root = x_root;
+ old_root = y_root;
+ } else {
+ // Both trees have the same rank, so break the tie by choosing
+ // x_root as the new root.
+ members_[y_root].parent = x_root;
+ // Increment the rank of the tree rooted at x_root, because it
+ // is now strictly deeper than before.
+ ++members_[x_root].rank;
+ new_root = x_root;
+ old_root = y_root;
+ }
+
+ // Merge the partial device specifications, and ensure that they are
+ // compatible. NULL options_ is treated as allowing soft placement.
+ // TODO(mrry): Consider enriching the error message by pointing
+ // out which nodes have the explicit partial device
+ // specifications that caused this conflict.
+ TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
+ &members_[new_root].device_name, members_[old_root].device_name,
+ options_ == nullptr || options_->config.allow_soft_placement()));
+
+ // Ensure that the common root has at least one supported device
+ // type, by computing the intersection of
+ // members_[new_root].supported_device_types and
+ // members_[old_root].supported_device_types.
+ MergeSupportedDevices(&members_[new_root].supported_device_types,
+ members_[old_root].supported_device_types);
+ if (members_[x_root].supported_device_types.size() == 0) {
+ return errors::InvalidArgument(
+ "Cannot colocate nodes '", x.name(), "' and '", y.name(),
+ "' because no device type supports both of those nodes and the "
+ "other nodes colocated with them");
+ }
+ }
+ return Status::OK();
+ }
+
+ // For the given node, subject to the constraints previously given
+ // to this ColocationGraph, set its assigned_device_name. Returns OK
+ // if a satisfying device can be found, otherwise an error.
+ Status AssignDevice(Node* node) {
+ int node_root = FindRoot(node->id());
+ if (members_[node_root].assigned_device == nullptr) {
+ // We have not yet assigned a device for the colocated node set containing
+ // n, so we do so now using the constraints on the root node.
+
+ // "devices" will contain the set of feasible placements for the
+ // colocated node set containing n.
+ std::vector<Device*> devices;
+ if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
+ // The root node has a (possibly partial) device
+ // specification, so enumerate the physical devices that
+ // conform to it.
+ device_set_->FindMatchingDevices(members_[node_root].device_name,
+ &devices);
+
+ if (!devices.empty()) {
+ // Filter devices into those that are compatible with the root
+ // node (and its children).
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+
+ // Perform soft placement if allow_soft_placement is set. options_
+ // being NULL is treated as allowing soft placement.
+ if (devices.empty() &&
+ (options_ == nullptr || options_->config.allow_soft_placement())) {
+ // The soft_device_name is the same as the node's device name
+ // without specifying the device type or ID.
+ DeviceNameUtils::ParsedName soft_device_name =
+ members_[node_root].device_name;
+ soft_device_name.type.clear();
+ soft_device_name.has_type = false;
+ soft_device_name.has_id = false;
+ device_set_->FindMatchingDevices(soft_device_name, &devices);
+ if (!devices.empty()) {
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+ }
+
+ if (devices.empty()) {
+ // Return an error when a physical device that matches an explicit
+ // device specification is not found. This ensures that we don't
+ // assign a node to GPU when the user wanted to force it on CPU.
+ DeviceNameUtils::ParsedName specified_device_name;
+ if (DeviceNameUtils::ParseFullName(node->def().device(),
+ &specified_device_name) &&
+ specified_device_name == members_[node_root].device_name) {
+ // The specified device and merged set device match, and
+ // will appear in the GraphDef (for debugging), so just
+ // print the specified device.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(), "'");
+ } else {
+ // The specified device may be a valid device but the
+ // merged set device is different, so print both.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(),
+ "' because the node was colocated with a group of nodes that "
+ "required incompatible device '",
+ DeviceNameUtils::ParsedNameToString(
+ members_[node_root].device_name),
+ "'");
+ }
+ }
+ } else {
+ // The device is completely unspecified, so enumerate the devices that
+ // support all of the nodes in the set.
+ if (device_set_->devices().empty()) {
+ return errors::Internal("No devices are registered");
+ }
+ devices = FilterSupportedDevices(
+ device_set_->devices(), members_[node_root].supported_device_types);
+
+ if (devices.empty()) {
+ return errors::InvalidArgument(
+ "Node had no OpKernel registered to support this operation: ",
+ "Operation was ", node->type_string(), " and inputs were ",
+ DataTypeVectorString(node->input_types()));
+ }
+ }
+
+ // Returns the first device in sorted devices list so we will always
+ // choose the same device.
+ members_[node_root].assigned_device = devices[0];
+ }
+ node->set_assigned_device_name(members_[node_root].assigned_device->name());
+
+ // Log placement if log_device_placement is set.
+ if (options_ && options_->config.log_device_placement()) {
+ printf("%s: %s\n", node->name().c_str(),
+ node->assigned_device_name().c_str());
+ LOG(INFO) << node->name() << ": " << node->assigned_device_name();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Represents a node in the disjoint node set forest, and the
+ // accumulated constraints on the device used by that node.
+ struct Member {
+ Member() = default;
+ // The id of the node that is the parent of this one, or its own
+ // id if it is a root. parent <= 0 indicates that this member is invalid.
+ int parent = -1;
+ // A proxy for the depth of the tree that is used to prefer
+ // connecting smaller trees to larger trees when merging disjoint
+ // sets.
+ int rank = 0;
+ // The intersection of all device types supported by this node,
+ // and those of all of its children, in priority order
+ // of the preferred device.
+ DeviceTypeVector supported_device_types;
+ // The merged form of the device requested for this node, with
+ // those of all of its children.
+ DeviceNameUtils::ParsedName device_name;
+ // If this node is a root, stores the Device to which this node
+ // and all of its children have been assigned, or nullptr if this
+ // has not yet been computed by GetAssignedDevice().
+ Device* assigned_device = nullptr;
+ };
+
+ Status InitializeMember(const Node& node, Member* member) {
+ const int id = node.id();
+ if (id < 0) {
+ return errors::InvalidArgument("Node id was not positive: ", id);
+ }
+ member->parent = id;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ device_types_, node.def(), &member->supported_device_types));
+
+ if (!node.assigned_device_name().empty()) {
+ // This node has already been assigned to a device, so we
+ // respect this placement, after sanity-checking it. The
+ // device_name and supported_device_types for this node reflect
+ // the assigned device, so any nodes colocated with this node
+ // will be assigned to the same device (assuming this is
+ // possible).
+ // NOTE: Since any assignment must have been performed by
+ // the TensorFlow runtime, we consider errors in this branch to
+ // be INTERNAL.
+ if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(),
+ &member->device_name)) {
+ return errors::Internal("Malformed assigned device '",
+ node.assigned_device_name(), "'");
+ }
+ std::vector<Device*> devices;
+ const Device* assigned_device =
+ device_set_->FindDeviceByName(node.assigned_device_name());
+ if (assigned_device == nullptr) {
+ return errors::Internal("Assigned device '",
+ node.assigned_device_name(),
+ "' does not match any device");
+ }
+
+ for (DeviceType d : member->supported_device_types) {
+ if (DeviceType(assigned_device->attributes().device_type()) == d) {
+ return Status::OK();
+ }
+ }
+
+ return errors::Internal("Assigned device '", node.assigned_device_name(),
+ "' does not have registered OpKernel support "
+ "for ",
+ node.def().op());
+ } else {
+ // This node has not yet been assigned to a device, so we
+ // calculate any constraints due to the set of registered
+ // kernels and any (partial) user-provided device specification
+ // in the NodeDef.
+
+ // If no kernels are registered for this op type, fail with an error.
+ if (member->supported_device_types.empty()) {
+ return errors::InvalidArgument(
+ "No OpKernel was registered to support "
+ "Op '",
+ node.def().op(), "' with these attrs");
+ }
+
+ // If the NodeDef contains a device that is *not* a colocated node name
+ // (i.e. it does not begin with '@') then we interpret it as a (partial)
+ // device specification.
+ string colocated_node_name;
+ if (!node.def().device().empty() && !HasColocatedNodeName(node)) {
+ // The user has specified a device in the NodeDef, try to find a
+ // valid device matching their specification in the set of
+ // devices.
+ // NOTE: The full name may specify a device that is not in
+ // n.supported_device_types(), but we check that in AssignDevice().
+ if (!DeviceNameUtils::ParseFullName(node.def().device(),
+ &member->device_name)) {
+ return errors::InvalidArgument("Malformed device specification '",
+ node.def().device(), "'");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ // Updates target to contain the intersection of the device types in
+ // "target" and "other".
+ static void MergeSupportedDevices(DeviceTypeVector* target,
+ const DeviceTypeVector& other) {
+ DeviceTypeVector temp = *target;
+ target->clear();
+
+ // Iterate in priority order.
+ for (DeviceType device_type : temp) {
+ bool found = false;
+ for (DeviceType other_device_type : other) {
+ if (device_type == other_device_type) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ target->push_back(device_type);
+ }
+ }
+ }
+
+ // Returns the root node of the disjoint tree to which the node with the
+ // given id is connected.
+ int FindRoot(int node_id) {
+ DCHECK_GE(members_[node_id].parent, 0);
+ if (members_[node_id].parent != node_id) {
+ // NOTE: Compress paths from node_id to its root, so that future
+ // calls to FindRoot and ColocateNodes are more efficient.
+ members_[node_id].parent = FindRoot(members_[node_id].parent);
+ }
+ return members_[node_id].parent;
+ }
+
+ std::vector<Member> members_;
+ const DeviceSet* device_set_; // Not owned.
+ const std::vector<DeviceType> device_types_;
+ const SessionOptions* options_; // Not owned;
+};
+
+} // namespace
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map,
+ const SessionOptions* options)
+ : graph_(graph),
+ devices_(devices),
+ name_to_id_map_(name_to_id_map),
+ options_(options) {}
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map)
+ : graph_(graph), devices_(devices), name_to_id_map_(name_to_id_map) {
+ options_ = nullptr;
+}
+
+SimplePlacer::~SimplePlacer() {}
+
+Status SimplePlacer::Run() {
+ if (devices_->devices().empty()) {
+ return errors::FailedPrecondition("No devices are registered");
+ }
+
+ ColocationGraph colocation_graph(graph_, devices_, options_);
+ Status status;
+
+ // 1. First add all of the nodes. Note that steps (1) and (2)
+ // requires two passes over the nodes because the graph (and hence
+ // the constraints) may not be acyclic.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+ status = colocation_graph.AddNode(*node);
+ if (!status.ok()) return AttachDef(status, node->def());
+ }
+
+ // 2. Enumerate the constraint edges, and use them to update the disjoint
+ // node set.
+ for (Node* node : graph_->nodes()) {
+ if (!node->IsOp()) {
+ continue;
+ }
+
+ // 2(a). If node n specifies a colocation constraint as its device name,
+ // add an edge from the colocated node to n.
+ if (HasColocatedNodeName(*node)) {
+ string colocated_node_name;
+ status = ParseColocatedNodeName(*node, &colocated_node_name);
+ if (!status.ok()) {
+ return AttachDef(status, node->def());
+ }
+ Node* colocated_node;
+ status = GetNodeByName(colocated_node_name, &colocated_node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Colocated node named in device '",
+ colocated_node_name, "' does not exist"),
+ node->def());
+ }
+ status = colocation_graph.ColocateNodes(*colocated_node, *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument(
+ "Cannot satisfy colocation constraint named in device '",
+ colocated_node_name, "': ", status.error_message()),
+ node->def());
+ }
+ }
+
+ // 2(b). If `node` has an input edge with reference type, add an
+ // edge from the source of that edge to `node`.
+ for (const auto& edge : node->in_edges()) {
+ if (!edge->IsControlEdge() &&
+ IsRefType(node->input_type(edge->dst_input()))) {
+ status = colocation_graph.ColocateNodes(*edge->src(), *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot satisfy colocation constraint "
+ "implied by reference connection: ",
+ status.error_message()),
+ node->def());
+ }
+ }
+ }
+ }
+
+ // 3. For each node, assign a device based on the constraints in the
+ // disjoint node set.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+ // Skip nodes that already have an assigned name.
+ if (!node->assigned_device_name().empty()) {
+ continue;
+ }
+
+ status = colocation_graph.AssignDevice(node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device to node '",
+ node->name(), "': ", status.error_message()),
+ node->def());
+ }
+ }
+ return Status::OK();
+}
+
+Status SimplePlacer::GetNodeByName(const string& name, Node** out_node) const {
+ NodeNameToIdMap::const_iterator iter = name_to_id_map_->find(name);
+ if (iter != name_to_id_map_->end()) {
+ *out_node = graph_->FindNodeId(iter->second);
+ if (*out_node) {
+ return Status::OK();
+ }
+ }
+ return errors::NotFound(name);
+}
+
+} // namespace tensorflow