diff options
Diffstat (limited to 'tensorflow/core/common_runtime/placer.cc')
-rw-r--r-- | tensorflow/core/common_runtime/placer.cc | 90 |
1 files changed, 79 insertions, 11 deletions
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 86851c2c07..6781c87f6c 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/status_util.h" namespace tensorflow { @@ -628,6 +629,40 @@ class ColocationGraph { return parent; } + // Ensures that the devices of 'dst's resource and reference match the device + // specified for 'src', which is an input of 'dst' with a partially or fully + // specified device. + Status VerifyResourceAndRefInputsCanBeColocated( + const Node* dst, const Node* src, + const DeviceNameUtils::ParsedName& src_parsed_name) { + std::vector<const Edge*> edges; + TF_RETURN_IF_ERROR(dst->input_edges(&edges)); + for (const Edge* edge : edges) { + DataType input_type = dst->input_type(edge->dst_input()); + if (input_type == DT_RESOURCE || IsRefType(input_type)) { + const Node* input_node = edge->src(); + if (input_node == src) { + continue; + } + const auto& input_root = members_[FindRoot(input_node->id())]; + const auto& input_parsed_name = input_root.device_name; + if (DeviceNameUtils::HasSomeDetails(input_parsed_name) && + !DeviceNameUtils::AreCompatibleDevNames(input_parsed_name, + src_parsed_name)) { + return AttachDef( + errors::InvalidArgument( + "Could not colocate node with its " + "resource and reference inputs; devices ", + DeviceNameUtils::ParsedNameToString(input_parsed_name), + " and ", DeviceNameUtils::ParsedNameToString(src_parsed_name), + " are not compatible."), + *dst); + } + } + } + return Status::OK(); + } + Graph* const graph_; // Not owned. std::vector<Member> members_; const DeviceSet* device_set_; // Not owned. @@ -646,6 +681,15 @@ bool IsGeneratorNode(const Node* node) { !IsRefType(node->output_type(0)); } +bool IsExemptFromResourceInputColocation(const Node* node) { + // Note: Partitioned function calls, which place and partition their + // function bodies, are exempt from this check: they forward resource and + // ref inputs to operations that are appropriately placed, instead of + // dereferencing them. + const string& op_type = node->op_def().name(); + return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall"; +} + } // namespace Placer::Placer(Graph* graph, const DeviceSet* devices, @@ -680,8 +724,8 @@ Status Placer::Run() { // 2. Enumerate the constraint edges, and use them to update the disjoint // node set. - // If `node` has an input edge with reference type, add an - // edge from the source of that edge to `node`. + // If `node` has an input edge with reference type, add an edge from the + // source of that edge to `node`. for (const Edge* edge : graph_->edges()) { if (edge->IsControlEdge()) { continue; @@ -689,7 +733,10 @@ Status Placer::Run() { Node* src = edge->src(); Node* dst = edge->dst(); DataType input_type = dst->input_type(edge->dst_input()); - if (input_type == DT_RESOURCE || IsRefType(input_type)) { + if ((input_type == DT_RESOURCE || IsRefType(input_type)) && + !IsExemptFromResourceInputColocation(dst)) { + // Colocate `src` and `dst` to maintain the invariant that nodes connected + // by reference edges are colocated. int src_root_id = colocation_graph.FindRoot(src->id()); int dst_root_id = colocation_graph.FindRoot(dst->id()); auto& src_root = colocation_graph.members_[src_root_id]; @@ -706,6 +753,9 @@ Status Placer::Run() { // incompatible. if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, dest_parsed_name)) { + TF_RETURN_IF_ERROR( + colocation_graph.VerifyResourceAndRefInputsCanBeColocated( + dst, src, source_parsed_name)); if (log_device_placement_) { LOG(INFO) << "Ignoring device specification " << DeviceNameUtils::ParsedNameToString(dest_parsed_name) @@ -773,10 +823,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); + return AttachDef(errors::InvalidArgument( + "Cannot assign a device for operation ", + RichNodeName(node), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -820,10 +870,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); + return AttachDef(errors::InvalidArgument( + "Cannot assign a device for operation ", + RichNodeName(node), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -889,4 +939,22 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } +bool Placer::ClientHandlesErrorFormatting() const { + return options_ != nullptr && + options_->config.experimental().client_handles_error_formatting(); +} + +// Returns the node name in single quotes. If the client handles formatted +// errors, appends a formatting tag which the client will reformat into, for +// example, " (defined at filename:123)". +string Placer::RichNodeName(const Node* node) const { + string quoted_name = strings::StrCat("'", node->name(), "'"); + if (ClientHandlesErrorFormatting()) { + string file_and_line = error_format_tag(*node, "${file}:${line}"); + return strings::StrCat(quoted_name, " (defined at ", file_and_line, ")"); + } else { + return quoted_name; + } +} + } // namespace tensorflow |