From b5c1d0f8977e0f05c9aeeb9e5105500bf83972bb Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 7 Sep 2017 22:11:25 -0700 Subject: SimpleGraphExecutionState -> GraphExecutionState SimplePlacer -> Placer And clean up a couple unneeded headers. PiperOrigin-RevId: 167955883 --- tensorflow/core/BUILD | 10 +- tensorflow/core/common_runtime/direct_session.cc | 23 +- tensorflow/core/common_runtime/direct_session.h | 4 +- .../core/common_runtime/graph_execution_state.cc | 446 +++++++ .../core/common_runtime/graph_execution_state.h | 208 ++++ tensorflow/core/common_runtime/placer.cc | 880 ++++++++++++++ tensorflow/core/common_runtime/placer.h | 101 ++ tensorflow/core/common_runtime/placer_test.cc | 1283 +++++++++++++++++++ .../common_runtime/simple_graph_execution_state.cc | 447 ------- .../common_runtime/simple_graph_execution_state.h | 209 ---- tensorflow/core/common_runtime/simple_placer.cc | 881 -------------- tensorflow/core/common_runtime/simple_placer.h | 102 -- .../core/common_runtime/simple_placer_test.cc | 1285 -------------------- .../core/distributed_runtime/master_session.cc | 22 +- .../core/distributed_runtime/master_session.h | 4 +- 15 files changed, 2949 insertions(+), 2956 deletions(-) create mode 100644 tensorflow/core/common_runtime/graph_execution_state.cc create mode 100644 tensorflow/core/common_runtime/graph_execution_state.h create mode 100644 tensorflow/core/common_runtime/placer.cc create mode 100644 tensorflow/core/common_runtime/placer.h create mode 100644 tensorflow/core/common_runtime/placer_test.cc delete mode 100644 tensorflow/core/common_runtime/simple_graph_execution_state.cc delete mode 100644 tensorflow/core/common_runtime/simple_graph_execution_state.h delete mode 100644 tensorflow/core/common_runtime/simple_placer.cc delete mode 100644 tensorflow/core/common_runtime/simple_placer.h delete mode 100644 tensorflow/core/common_runtime/simple_placer_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9319928307..4b80d2c543 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1773,12 +1773,14 @@ tf_cuda_library( "common_runtime/device_set.cc", "common_runtime/executor.cc", "common_runtime/function.cc", + "common_runtime/graph_execution_state.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", "common_runtime/local_device.cc", "common_runtime/memory_types.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/placer.cc", "common_runtime/process_function_library_runtime.cc", "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", @@ -1788,8 +1790,6 @@ tf_cuda_library( "common_runtime/session_factory.cc", "common_runtime/session_options.cc", "common_runtime/session_state.cc", - "common_runtime/simple_graph_execution_state.cc", - "common_runtime/simple_placer.cc", "common_runtime/stats_publisher_interface.cc", "common_runtime/step_stats_collector.cc", "common_runtime/threadpool_device.cc", @@ -1829,8 +1829,8 @@ tf_cuda_library( "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", "common_runtime/session_factory.h", - "common_runtime/simple_graph_execution_state.h", - "common_runtime/simple_placer.h", + "common_runtime/graph_execution_state.h", + "common_runtime/placer.h", "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", @@ -2321,9 +2321,9 @@ tf_cc_tests( "common_runtime/device_set_test.cc", "common_runtime/optimization_registry_test.cc", 
"common_runtime/pending_counts_test.cc", + "common_runtime/placer_test.cc", "common_runtime/resource_variable_read_optimizer_test.cc", "common_runtime/session_test.cc", - "common_runtime/simple_placer_test.cc", "example/feature_util_test.cc", "framework/allocator_test.cc", "framework/attr_value_util_test.cc", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index a6630f38a5..d28857bb9f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -345,20 +344,20 @@ Status DirectSession::MaybeInitializeExecutionState( // all subsequent extensions of the graph. flib_def_.reset( new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); - SimpleGraphExecutionStateOptions options; + GraphExecutionStateOptions options; options.device_set = &device_set_; options.session_options = &options_; // TODO(mrry,suharshs): We explicitly copy `graph` so that // `MakeForBaseGraph()` can take ownership of its // contents. Previously this happened implicitly in calls to the - // `SimpleGraphExecutionState`. Other sessions call + // `GraphExecutionState`. Other sessions call // `MakeForBaseGraph` in such a way that we can destructively read // the passed-in `GraphDef`. In principle we could do the same here, // with a wider refactoring; we might revise the direct session so // that it copies the graph fewer times. GraphDef temp(graph); - TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( - &temp, options, &execution_state_)); + TF_RETURN_IF_ERROR( + GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_)); graph_created_ = true; *out_already_initialized = false; return Status::OK(); @@ -391,7 +390,7 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) { MaybeInitializeExecutionState(graph, &already_initialized)); if (already_initialized) { TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); - std::unique_ptr state; + std::unique_ptr state; TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); execution_state_.swap(state); } @@ -1280,19 +1279,19 @@ Status DirectSession::CreateGraphs( RunStateArgs* run_state_args, DataTypeVector* input_types, DataTypeVector* output_types) { mutex_lock l(graph_def_lock_); - std::unique_ptr client_graph; + std::unique_ptr client_graph; - std::unique_ptr temp_exec_state_holder; - SimpleGraphExecutionState* execution_state = nullptr; + std::unique_ptr temp_exec_state_holder; + GraphExecutionState* execution_state = nullptr; if (options_.config.graph_options().place_pruned_graph()) { // Because we are placing pruned graphs, we need to create a - // new SimpleGraphExecutionState for every new unseen graph, + // new GraphExecutionState for every new unseen graph, // and then place it. 
- SimpleGraphExecutionStateOptions prune_options; + GraphExecutionStateOptions prune_options; prune_options.device_set = &device_set_; prune_options.session_options = &options_; prune_options.stateful_placements = stateful_placements_; - TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph( + TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph( execution_state_->original_graph_def().library(), prune_options, execution_state_->original_graph_def(), subgraph_options, &temp_exec_state_holder, &client_graph)); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 020831d6cc..7fbabf6d81 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -28,10 +28,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/session_factory.h" -#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/session_state.h" @@ -310,7 +310,7 @@ class DirectSession : public Session { GUARDED_BY(graph_def_lock_); // Execution_state; used when placing the entire graph. - std::unique_ptr execution_state_ + std::unique_ptr execution_state_ GUARDED_BY(graph_def_lock_); // The function library, before any rewrites or optimizations have been diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc new file mode 100644 index 0000000000..4bd40c7978 --- /dev/null +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -0,0 +1,446 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/graph_execution_state.h" + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/placer.h" +#include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/util.h" + +#ifndef IS_MOBILE_PLATFORM +#include "tensorflow/core/grappler/clusters/utils.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#endif // IS_MOBILE_PLATFORM + +namespace tensorflow { + +GraphExecutionState::GraphExecutionState( + GraphDef* graph_def, const GraphExecutionStateOptions& options) + : stateful_placements_(options.stateful_placements), + device_set_(options.device_set), + session_options_(options.session_options), + flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(), + graph_def->library())), + graph_(nullptr) { + // NOTE(mrry): GraphDef does not have a move constructor, so we pass + // a non-const pointer and use `Swap()` to transfer the contents + // without copying. + original_graph_def_.Swap(graph_def); + // TODO(mrry): Publish placement visualizations or handle the log + // placement option. +} + +GraphExecutionState::~GraphExecutionState() { + node_name_to_cost_id_map_.clear(); + delete graph_; +} + +/* static */ Status GraphExecutionState::MakeForBaseGraph( + GraphDef* graph_def, const GraphExecutionStateOptions& options, + std::unique_ptr* out_state) { + std::unique_ptr ret( + new GraphExecutionState(graph_def, options)); + + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); + // TODO(mrry): Refactor InitBaseGraph() so that we don't have to + // pass an empty BuildGraphOptions (that isn't going to be used when + // place_pruned_graph is false). + if (!ret->session_options_->config.graph_options().place_pruned_graph()) { + TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions())); + } + *out_state = std::move(ret); + return Status::OK(); +} + +/* static */ Status GraphExecutionState::MakeForPrunedGraph( + const FunctionDefLibrary& func_def_lib, + const GraphExecutionStateOptions& options, const GraphDef& graph_def, + const BuildGraphOptions& subgraph_options, + std::unique_ptr* out_state, + std::unique_ptr* out_client_graph) { + DCHECK(options.session_options->config.graph_options().place_pruned_graph()); + // NOTE(mrry): This makes a copy of `graph_def`, which is + // regrettable. 
We could make `GraphDef` objects sharable between + // execution states to optimize pruned graph execution, but since + // this case is primarily used for interactive sessions, we make the + // bet that graph construction is not performance-critical. (Note + // also that the previous version used `Extend()`, which is strictly + // more expensive than copying a `GraphDef`.) + GraphDef temp(graph_def); + std::unique_ptr ret( + new GraphExecutionState(&temp, options)); + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); + TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options)); + TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph)); + *out_state = std::move(ret); + return Status::OK(); +} + +Status GraphExecutionState::Extend( + const GraphDef& extension_def, + std::unique_ptr* out) const { + GraphDef gdef; + + // 1. Copy the function library. + TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library())); + *gdef.mutable_library() = flib_def_->ToProto(); + + // 2. Build an index of the new node names. + std::unordered_set new_names; + for (const NodeDef& node : extension_def.node()) { + new_names.insert(node.name()); + } + + // 3. Add the non-duplicates from the old graph to the new graph. + // Return an error if the same node name appears in both the + // old graph and the extension. + for (const NodeDef& node : original_graph_def_.node()) { + if (new_names.count(node.name()) == 0) { + *gdef.add_node() = node; + } else { + return errors::InvalidArgument(tensorflow::strings::Printf( + "GraphDef argument to Extend includes node '%s', which was created " + "by a previous call to Create or Extend in this session.", + node.name().c_str())); + } + } + + // 4. Merge the versions field. + int old_node_size = gdef.node_size(); + gdef.mutable_node()->MergeFrom(extension_def.node()); + TF_RETURN_IF_ERROR( + AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size)); + // Merge versions + if (gdef.has_versions()) { + if (gdef.versions().producer() != extension_def.versions().producer()) { + return errors::InvalidArgument( + "Can't extend GraphDef at version ", gdef.versions().producer(), + " with graph at version ", extension_def.versions().producer()); + } + VersionDef* versions = gdef.mutable_versions(); + versions->set_min_consumer(std::max( + versions->min_consumer(), extension_def.versions().min_consumer())); + if (extension_def.versions().bad_consumers_size()) { + // Add new bad_consumers that aren't already marked bad. + // + // Note: This implementation is quadratic time if there are many calls to + // ExtendLocked with many bad consumers. Since this is unlikely, and + // fixing it would require data structures outside of this routine, + // quadratic time it is. + auto* bad_consumers = versions->mutable_bad_consumers(); + const std::unordered_set existing(bad_consumers->begin(), + bad_consumers->end()); + for (const int v : extension_def.versions().bad_consumers()) { + if (existing.find(v) == existing.end()) { + bad_consumers->Add(v); + } + } + } + + } else { + gdef.mutable_versions()->CopyFrom(extension_def.versions()); + } + + // 5. Validate that the final graphdef is valid. + if (gdef.versions().producer() >= 5) { + // Validate the graph: we assume that merging two valid graphs + // should maintain graph validity. + TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_)); + } + + // 6. Add the extension. 
+ GraphExecutionStateOptions combined_options; + combined_options.device_set = device_set_; + combined_options.session_options = session_options_; + combined_options.stateful_placements = stateful_placements_; + + // NOTE(mrry): `gdef` is no longer valid after the constructor + // executes. + std::unique_ptr new_execution_state( + new GraphExecutionState(&gdef, combined_options)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( + &new_execution_state->original_graph_def_, *flib_def_, 0)); + if (!session_options_->config.graph_options().place_pruned_graph()) { + // TODO(mrry): Refactor InitBaseGraph() so that we don't have to + // pass an empty BuildGraphOptions (that isn't going to be used + // when place_pruned_graph is false). + TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions())); + } + *out = std::move(new_execution_state); + + // TODO(mrry): This is likely to be used for non-throughput-sensitive + // interactive workloads, but in future we may want to transfer other + // parts of the placement and/or cost model. + return Status::OK(); +} + +void GraphExecutionState::SaveStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + VLOG(2) << "Saving " << n->DebugString(); + stateful_placements_[n->name()] = n->assigned_device_name(); + } + } +} + +void GraphExecutionState::RestoreStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + auto iter = stateful_placements_.find(n->name()); + if (iter != stateful_placements_.end()) { + n->set_assigned_device_name(iter->second); + VLOG(2) << "Restored " << n->DebugString(); + } + } + } +} + +Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) { + const GraphDef* graph_def = &original_graph_def_; + + std::unique_ptr new_graph(new Graph(OpRegistry::Global())); + GraphConstructorOptions opts; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get())); + for (const Node* n : new_graph->nodes()) { + VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id(); + node_name_to_cost_id_map_[n->name()] = n->cost_id(); + } + if (session_options_ && + session_options_->config.graph_options().place_pruned_graph()) { + // Rewrite the graph before placement. + rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata); + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + new_graph.get(), options.feed_endpoints, options.fetch_endpoints, + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, rewrite_metadata_.get())); + } + + // Save stateful placements before placing. + RestoreStatefulNodes(new_graph.get()); + + GraphOptimizationPassOptions optimization_options; + optimization_options.session_options = session_options_; + optimization_options.graph = &new_graph; + optimization_options.flib_def = flib_def_.get(); + optimization_options.device_set = device_set_; + + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); + + Placer placer(new_graph.get(), device_set_, session_options_); + // TODO(mrry): Consider making the Placer cancelable. 
+ TF_RETURN_IF_ERROR(placer.Run()); + + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PLACEMENT, optimization_options)); + + SaveStatefulNodes(new_graph.get()); + graph_ = new_graph.release(); + return Status::OK(); +} + +Status GraphExecutionState::OptimizeGraph( + const BuildGraphOptions& options, std::unique_ptr* optimized_graph) { +#ifndef IS_MOBILE_PLATFORM + if (session_options_->config.graph_options().place_pruned_graph()) { + return errors::InvalidArgument("Can't optimize a pruned graph"); + } + + const RewriterConfig& rewrite_options = + session_options_->config.graph_options().rewrite_options(); + + if (grappler::MetaOptimizerEnabled(rewrite_options)) { + // Adding this functionality in steps. The first step is to make sure + // we don't break dependencies. The second step will be to turn the + // functionality on by default. + grappler::GrapplerItem item; + item.id = "tf_graph"; + graph_->ToGraphDef(&item.graph); + + item.fetch = options.fetch_endpoints; + item.fetch.insert(item.fetch.end(), options.target_nodes.begin(), + options.target_nodes.end()); + + if (!options.feed_endpoints.empty()) { + std::unordered_set feeds; + for (const string& feed : options.feed_endpoints) { + TensorId id = ParseTensorName(feed); + if (id.second != 0) { + return errors::InvalidArgument("Unsupported feed: ", feed); + } + feeds.insert(id.first.ToString()); + } + for (const NodeDef& node : original_graph_def_.node()) { + if (feeds.find(node.name()) == feeds.end()) { + continue; + } + if (node.attr().count("dtype") == 0 || + node.attr().count("shape") == 0) { + return errors::InvalidArgument("Missing node shape or type"); + } + TensorShapeProto shape_proto(node.attr().at("shape").shape()); + // If the shape of the placeholder value is only partially known, we're + // free to use any dimension we want to feed the placeholder. We choose + // 1 to minimize the memory impact. Note that this only matters if an + // optimizer choose to run the graph to build its cost model, which + // doesn't happen (yet) + if (shape_proto.unknown_rank()) { + shape_proto.set_unknown_rank(false); + } + for (auto& dim : *shape_proto.mutable_dim()) { + if (dim.size() < 0) { + dim.set_size(1); + } + } + TensorShape shape(shape_proto); + DataType type = node.attr().at("dtype").type(); + Tensor fake_input(type, shape); + item.feed.emplace_back(node.name(), fake_input); + } + } + + std::unordered_map device_map; + Device* cpu_device = nullptr; + for (const auto& device : device_set_->devices()) { + device_map[device->name()] = + grappler::GetDeviceInfo(device->parsed_name()); + if (device->parsed_name().id == 0 && + StringPiece(device->parsed_name().type) == "CPU" && + device->GetAllocator(AllocatorAttributes()) != nullptr) { + cpu_device = device; + } + } + if (cpu_device == nullptr) { + return errors::Internal( + "Unable to find CPU device needed for constant folding"); + } + grappler::VirtualCluster cluster(device_map); + GraphDef new_graph; + TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer( + item, rewrite_options, cpu_device, &cluster, &new_graph)); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + optimized_graph->reset(new Graph(OpRegistry::Global())); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get())); + // The graph conversion sets the requested device names but not the assigned + // device names. However, since at this point the graph is placed TF expects + // an assigned device name for every node. 
Therefore we copy the requested + // device into the assigned device field. + for (Node* node : optimized_graph->get()->nodes()) { + node->set_assigned_device_name(node->requested_device()); + } + return Status::OK(); + } else { + return errors::InvalidArgument("Meta Optimizer disabled"); + } +#else + return errors::InvalidArgument("Mobile platforms not supported"); +#endif // IS_MOBILE_PLATFORM +} + +Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, + std::unique_ptr* out) { + VLOG(1) << "BuildGraph"; + if (!graph_) { + // It is only valid to call this method directly when the original graph + // was created with the option `place_pruned_graph == false`. + return errors::Internal( + "Attempted to prune a graph that has not been fully initialized."); + } + + std::unique_ptr ng; + Status s = OptimizeGraph(options, &ng); + if (!s.ok()) { + // Simply copy the original graph if we couldn't optimize it. + ng.reset(new Graph(flib_def_.get())); + CopyGraph(*graph_, ng.get()); + } + + subgraph::RewriteGraphMetadata rewrite_metadata; + if (session_options_ == nullptr || + !session_options_->config.graph_options().place_pruned_graph()) { + // Extract the subset of the graph that needs to be run, adding feed/fetch + // ops as needed. + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + ng.get(), options.feed_endpoints, options.fetch_endpoints, + options.target_nodes, device_set_->client_device()->attributes(), + options.use_function_convention, &rewrite_metadata)); + } else { + // This GraphExecutionState represents a graph that was + // pruned when this was constructed, so we copy the metadata from + // a member variable. + CHECK(rewrite_metadata_); + rewrite_metadata = *rewrite_metadata_; + } + + CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size()); + CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size()); + + // Make a fresh copy of the function library for the client graph. + std::unique_ptr flib( + new FunctionLibraryDefinition(*flib_def_)); + + // TODO(andydavis): Clarify optimization pass requirements around CostModel. + GraphOptimizationPassOptions optimization_options; + optimization_options.session_options = session_options_; + optimization_options.graph = &ng; + optimization_options.flib_def = flib.get(); + optimization_options.device_set = device_set_; + + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); + + // Copy the extracted graph in order to make its node ids dense, + // since the local CostModel used to record its stats is sized by + // the largest node id. + std::unique_ptr dense_copy( + new ClientGraph(std::move(flib), rewrite_metadata.feed_types, + rewrite_metadata.fetch_types)); + CopyGraph(*ng, &dense_copy->graph); + + // TODO(vrv): We should check invariants of the graph here. + + *out = std::move(dense_copy); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h new file mode 100644 index 0000000000..db2686ce2c --- /dev/null +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -0,0 +1,208 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
+
+#include
+#include
+#include
+#include
+
+#include "tensorflow/core/common_runtime/build_graph_options.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+struct SessionOptions;
+
+namespace subgraph {
+struct RewriteGraphMetadata;
+}
+
+struct GraphExecutionStateOptions {
+  const DeviceSet* device_set = nullptr;
+  const SessionOptions* session_options = nullptr;
+  // A map from node name to device name, representing the unchangeable
+  // placement of stateful nodes.
+  std::unordered_map<string, string> stateful_placements;
+};
+
+// A ClientGraph is simply a sub-graph of the full graph as induced by
+// BuildGraphOptions.
+struct ClientGraph {
+  explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
+                       DataTypeVector feed_types, DataTypeVector fetch_types)
+      : flib_def(std::move(flib)),
+        graph(flib_def.get()),
+        feed_types(std::move(feed_types)),
+        fetch_types(std::move(fetch_types)) {}
+  // Each client-graph gets its own function library since optimization passes
+  // post rewrite for execution might want to introduce new functions.
+  std::unique_ptr<FunctionLibraryDefinition> flib_def;
+  Graph graph;
+  DataTypeVector feed_types;
+  DataTypeVector fetch_types;
+};
+
+// GraphExecutionState is responsible for generating an
+// executable ClientGraph from the original GraphDef that specifies
+// the complete graph and from BuildGraphOptions which specifies
+// input/output nodes.
+//
+// An executable Graph differs from a GraphDef by being Placed,
+// meaning that each Node is assigned to a single Device in the
+// available set.
+//
+// When GraphExecutionState is first constructed it instantiates
+// a full Graph from the provided GraphDef, and places it, using only
+// the static device assignments from the GraphDef. Nodes without static
+// device assignments are currently placed in a very naive way. Since
+// stateful Nodes cannot be moved after initial placement, it is
+// important that stateful Nodes get sensible initial device assignments
+// in the graph definition.
+//
+// Subsequently, GraphExecutionState generates a ClientGraph on
+// demand, which is a sub-graph of the latest placement of the full
+// Graph. MasterSession uses such a ClientGraph to execute one or
+// more similar client requests.
+//
+// GraphExecutionState is thread-safe.
+
+class GraphExecutionState {
+ public:
+  virtual ~GraphExecutionState();
+
+  // Creates a new `GraphExecutionState` for the given
+  // `graph_def`, which represents the entire graph for a session.
+  //
+  // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
+  // in an undefined state. If it is necessary to use `*graph_def`
+  // after this call, make an explicit copy of the graph before
+  // calling this method.
+  static Status MakeForBaseGraph(
+      GraphDef* graph_def, const GraphExecutionStateOptions& options,
+      std::unique_ptr<GraphExecutionState>* out_state);
+
+  // Creates a new `GraphExecutionState` and `ClientGraph`
+  // for the subgraph of `original_graph_def` defined by
+  // `subgraph_options`.
+  static Status MakeForPrunedGraph(
+      const FunctionDefLibrary& func_def_lib,
+      const GraphExecutionStateOptions& options,
+      const GraphDef& original_graph_def,
+      const BuildGraphOptions& subgraph_options,
+      std::unique_ptr<GraphExecutionState>* out_state,
+      std::unique_ptr<ClientGraph>* out_client_graph);
+
+  // Creates a new GraphExecutionState representing the
+  // concatenation of this graph, and the graph defined by
+  // "extension_def". The same name may not be used to define a node
+  // in both this graph and "extension_def".
+  //
+  // If successful, returns OK and the caller takes ownership of "*out".
+  // Otherwise returns an error and does not modify "*out".
+  //
+  // After calling `old_state->Extend()`, `old_state` may no longer be
+  // used.
+  //
+  // NOTE(mrry): This method respects the placement of stateful nodes
+  // in *this, but currently does not transfer any other placement
+  // or cost model information to the new graph.
+  Status Extend(const GraphDef& extension_def,
+                std::unique_ptr<GraphExecutionState>* out) const;
+
+  // Builds a ClientGraph (a sub-graph of the full graph as induced by
+  // the Node set specified in "options"). If successful, returns OK
+  // and the caller takes the ownership of "*out". Otherwise, returns
+  // an error.
+  Status BuildGraph(const BuildGraphOptions& options,
+                    std::unique_ptr<ClientGraph>* out);
+
+  // The graph returned by BuildGraph may contain only the pruned
+  // graph, whereas some clients may want access to the full graph.
+  const Graph* full_graph() {
+    return graph_;
+  }
+
+  // Returns the node with the given name, or null if it does not exist.
+  const Node* get_node_by_name(const string& name) const {
+    NodeNameToCostIdMap::const_iterator iter =
+        node_name_to_cost_id_map_.find(name);
+    if (iter != node_name_to_cost_id_map_.end()) {
+      return graph_->FindNodeId(iter->second);
+    } else {
+      return nullptr;
+    }
+  }
+
+  // Returns a reference to the current graph_def. The reference must not
+  // be used beyond the lifetime of the GraphExecutionState object.
+  const GraphDef& original_graph_def() { return original_graph_def_; }
+
+  // Returns the map of stateful placements as a map of
+  // node name to placement string.
+  std::unordered_map<string, string> GetStatefulPlacements() const {
+    return stateful_placements_;
+  }
+
+ private:
+  GraphExecutionState(GraphDef* graph_def,
+                      const GraphExecutionStateOptions& options);
+
+  Status InitBaseGraph(const BuildGraphOptions& options);
+
+  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
+  // is true, such as "params" and "queue" nodes. Once placed these
+  // nodes can not be moved to a different device. Maps node names to
+  // device names.
+  std::unordered_map<string, string> stateful_placements_;  // Immutable after
+                                                            // ctor.
+  void SaveStatefulNodes(Graph* graph);
+  void RestoreStatefulNodes(Graph* graph);
+
+  Status OptimizeGraph(const BuildGraphOptions& options,
+                       std::unique_ptr<Graph>* optimized_graph);
+
+  GraphDef original_graph_def_;            // Immutable after ctor.
+  const DeviceSet* device_set_;            // Not owned
+  const SessionOptions* session_options_;  // Not owned
+
+  // Map from name to Node for the full graph in placed_.
+ NodeNameToCostIdMap node_name_to_cost_id_map_; + + // 'flib_def_' is initialized from the initial graph def's library, + // and may be updated by a graph optimization pass. + std::unique_ptr flib_def_; + + // `rewrite_metadata_` is only set for GraphExecutionState + // objects created by `MakeForPrunedGraph()`. + std::unique_ptr rewrite_metadata_; + + // The dataflow graph owned by this object. + Graph* graph_; + + TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc new file mode 100644 index 0000000000..73fdf60fd5 --- /dev/null +++ b/tensorflow/core/common_runtime/placer.cc @@ -0,0 +1,880 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/placer.h" + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +namespace { + +// We hoist the conversion from C-style string literal to StringPiece here, +// so that we can avoid the many repeated calls to strlen(). +const StringPiece kColocationAttrNameStringPiece(kColocationAttrName); +const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); + +// Returns a list of devices sorted by preferred type and then name +// from 'devices' whose type is in 'supported_device_types'. This +// function searches the device types in 'supported_device_types' and +// returns the subset of devices that match. +std::vector FilterSupportedDevices( + const std::vector& devices, + const DeviceTypeVector& supported_device_types) { + std::vector filtered_devices; + for (const DeviceType& d : supported_device_types) { + for (Device* device : devices) { + if (DeviceType(device->attributes().device_type()) == d) { + filtered_devices.emplace_back(device); + } + } + } + + auto device_sort = [](const Device* a, const Device* b) { + auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type())); + auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type())); + // First sort by prioritized device type (higher is preferred) and + // then by device name (lexicographically). 
+ if (a_priority != b_priority) { + return a_priority > b_priority; + } + return StringPiece(a->name()) < StringPiece(b->name()); + }; + std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort); + return filtered_devices; +} + +// This class maintains the connected components of a colocation +// constraint graph, and uses this information to assign a satisfying +// device placement to the nodes of the graph. +// +// The typical usage pattern is: +// +// Graph graph = ...; +// DeviceSet device_set = ...; +// ColocationGraph colocation_graph(graph, device_set); +// +// // Add all the nodes of graph to colocation_graph. +// for (Node* node : graph.nodes()) { +// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node)); +// } +// +// // Add one or more colocation constraint. +// Node node_1 = *graph.FindNodeId(...); +// Node node_2 = *graph.FindNodeId(...); +// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2)); +// +// // Assign devices based on the accumulated constraints. +// for (Node* node : graph.nodes()) { +// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node)); +// } +// +// The implementation uses the union-find algorithm to maintain the +// connected components efficiently and incrementally as edges +// (implied by ColocationGraph::ColocateNodes() invocations) are added. +class ColocationGraph { + public: + ColocationGraph(Graph* graph, const DeviceSet* device_set, + bool allow_soft_placement) + : graph_(graph), + device_set_(device_set), + device_types_(device_set->PrioritizedDeviceTypeList()), + allow_soft_placement_(allow_soft_placement) { + members_.resize(graph->num_node_ids()); + } + + // Adds each node of the Graph to this ColocationGraph as a singleton. + // + // NOTE: The implementation assumes that the ids of nodes passed to + // this method are dense and zero-based; the memory used will be linear in + // the largest node ID. + // NOTE: If this method returns an error, *this is left in an undefined + // state. + Status ColocateAllNodes() { + // This maps from a colocation group identifier to the 'root' of that + // colocation group. Note that the keys in this map are StringPiece; the + // actual strings are stored under the NodeDef. The lifetime of this map + // is limited to this ColocateAllNodes() method, and no part of the + // NodeDef trees are changed during the lifetime of this method, so using + // StringPiece as a key is safe. + // + // Also, as a further optimization, we remove the "loc:@" prefix from + // "class" attribute values, when they are used as keys in this table. + // This allows us to use StringPiece values that refer to substrings of + // 'string' values stored in NodeDef attribute lists, as well as StringPiece + // values that refer to 'string' values from NodeDef::name(), without + // performing any string allocations. + std::unordered_map + colocation_group_root; + + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + + // When adding the node, identify whether it is part of a + // colocation group. + + // This code is effectively the equivalent of GetNodeAttr() for a string + // array, but it avoids all internal allocations (the allocation of the + // backing store of the std::vector as well as the copies of the + // strings within it). Instead, we combine the query of the colocation + // attribute with the calls to ColocateNodeToGroup. 
+ bool found_spec = false; + const AttrValue* attr_value = + node->attrs().Find(kColocationAttrNameStringPiece); + if (attr_value != nullptr && attr_value->has_list()) { + for (const string& class_spec : attr_value->list().s()) { + StringPiece spec(class_spec); + if (spec.Consume(kColocationGroupPrefixStringPiece)) { + found_spec = true; + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, spec)); + } + } + } + + if (!found_spec) { + // If the node does not specify a colocation group, then use the + // name of this node as the colocation group. + TF_RETURN_IF_ERROR( + ColocateNodeToGroup(&colocation_group_root, node, node->name())); + } + } + + return Status::OK(); + } + + Status ColocateNodeToGroup( + std::unordered_map* + colocation_group_root, + Node* node, StringPiece colocation_group) { + const Node*& root_node = (*colocation_group_root)[colocation_group]; + if (root_node == nullptr) { + // This is the first node of the colocation group, so + // designate this node as the 'root' of that colocation group. + root_node = node; + } else { + // Try to colocate the node with the root. If there is an + // error, return it. + Status s = ColocateNodes(*node, *root_node); + if (!s.ok()) { + return AttachDef(s, *node); + } + } + return Status::OK(); + } + + // Merge the (possibly disjoint) sets containing nodes "x" and + // "y". Returns OK if the all nodes in the union of these sets can + // be placed on the same device type. + // + // NOTE: If this method returns an error, *this is left in an undefined + // state. + Status ColocateNodes(const Node& x, const Node& y) { + int x_root = FindRoot(x.id()); + int y_root = FindRoot(y.id()); + return ColocateNodes(x, x_root, y, y_root); + } + + // This overload of ColocateNodes() allows a caller to provide the root node + // ids for the two nodes. For large graphs, this noticeably reduces the + // graph load time. + Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { + if (x_root == y_root) { + return Status::OK(); + } + + DCHECK_EQ(x_root, FindRoot(x.id())); + DCHECK_EQ(y_root, FindRoot(y.id())); + + Member& x_root_member = members_[x_root]; + Member& y_root_member = members_[y_root]; + + // Merge the sets by swinging the parent pointer of the smaller + // tree to point to the root of the larger tree. Together with + // path compression in ColocationGraph::FindRoot, this ensures + // that we do not experience pathological performance on graphs + // such as chains. + int new_root, old_root; + if (x_root_member.rank < y_root_member.rank) { + // The tree rooted at x_root is shallower, so connect it to + // y_root. The rank of y_root is unchanged because its new + // child has strictly less rank. + x_root_member.parent = y_root; + new_root = y_root; + old_root = x_root; + } else if (x_root_member.rank > y_root_member.rank) { + // The tree rooted at y_root is shallower, so connect it to + // x_root. The rank of x_root is unchanged because its new + // child has strictly less rank. + y_root_member.parent = x_root; + new_root = x_root; + old_root = y_root; + } else { + // Both trees have the same rank, so break the tie by choosing + // x_root as the new root. + y_root_member.parent = x_root; + // Increment the rank of the tree rooted at x_root, because it + // is now strictly deeper than before. 
+ ++x_root_member.rank; + new_root = x_root; + old_root = y_root; + } + + Member& new_root_member = members_[new_root]; + Member& old_root_member = members_[old_root]; + + // Merge the partial device specifications, and ensure that they are + // compatible. NULL options_ is treated as allowing soft placement. + // TODO(mrry): Consider enriching the error message by pointing + // out which nodes have the explicit partial device + // specifications that caused this conflict. + Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name, + old_root_member.device_name, + allow_soft_placement_); + if (!s.ok()) { + return errors::InvalidArgument("Cannot colocate nodes '", x.name(), + "' and '", y.name(), ": ", + s.error_message()); + } + + // Ensure that the common root has at least one supported device + // type, by computing the intersection of + // new_root_member.supported_device_types and + // old_root_member.supported_device_types. + MergeSupportedDevices(&new_root_member.supported_device_types, + old_root_member.supported_device_types); + if (new_root_member.supported_device_types.empty()) { + return errors::InvalidArgument( + "Cannot colocate nodes '", x.name(), "' and '", y.name(), + "' because no device type supports both of those nodes and the " + "other nodes colocated with them.", + DebugInfo(x_root), DebugInfo(y_root)); + } + + return Status::OK(); + } + + // For the given node, subject to the constraints previously given + // to this ColocationGraph, set its assigned_device_name. Returns OK + // if a satisfying device can be found, otherwise an error. + // + // Note: This method returns a pointer to a field within members_. + // The caller must not use the returned pointer after there is any possibility + // that the members_[i].possible_devices field has been modified. + Status GetDevicesForNode(Node* node, + std::vector** possible_devices) { + *possible_devices = nullptr; + const int node_root = FindRoot(node->id()); + if (!members_[node_root].possible_devices.empty()) { + *possible_devices = &members_[node_root].possible_devices; + return Status::OK(); + } + + // We have not yet computed the possible devices for the + // colocated node set containing 'node', so we do so now using the + // constraints on the root node. + + // "devices" will contain the set of feasible placements for the + // colocated node set containing 'node'. + std::vector devices; + if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) { + // The root node has a (possibly partial) device + // specification, so enumerate the physical devices that + // conform to it. + device_set_->FindMatchingDevices(members_[node_root].device_name, + &devices); + + if (!devices.empty()) { + // Filter devices into those that are compatible with the root + // node (and its children). + devices = FilterSupportedDevices( + devices, members_[node_root].supported_device_types); + } + + // Perform soft placement if allow_soft_placement_ is set. + if (devices.empty() && allow_soft_placement_) { + // The soft_device_name is the same as the node's device name + // without specifying the device type or ID. 
+ DeviceNameUtils::ParsedName soft_device_name = + members_[node_root].device_name; + soft_device_name.type.clear(); + soft_device_name.has_type = false; + soft_device_name.has_id = false; + device_set_->FindMatchingDevices(soft_device_name, &devices); + if (!devices.empty()) { + devices = FilterSupportedDevices( + devices, members_[node_root].supported_device_types); + } + } + + if (devices.empty()) { + // Return an error when a physical device that matches an explicit + // device specification is not found. This ensures that we don't + // assign a node to GPU when the user wanted to force it on CPU. + string debug_info = DebugInfo(node_root); + + DeviceNameUtils::ParsedName specified_device_name; + if (DeviceNameUtils::ParseFullName(node->requested_device(), + &specified_device_name) && + specified_device_name == members_[node_root].device_name) { + // The specified device and merged set device match, and + // will appear in the GraphDef (for debugging), so just + // print the specified device. + std::vector devices_matching_nodedef; + device_set_->FindMatchingDevices(specified_device_name, + &devices_matching_nodedef); + if (devices_matching_nodedef.empty()) { + // Sometimes it is almost impossible to understand the problem + // without a list of available devices. + std::vector device_names; + for (const Device* device : device_set_->devices()) { + device_names.push_back(device->name()); + } + std::sort(device_names.begin(), device_names.end()); + + return errors::InvalidArgument( + "Operation was explicitly assigned to ", + node->requested_device(), " but available devices are [ ", + str_util::Join(device_names, ", "), " ]. Make sure ", + "the device specification refers to a valid device."); + } else if (specified_device_name.has_type) { + return errors::InvalidArgument( + "Could not satisfy explicit device specification '", + node->requested_device(), "' because no supported kernel for ", + specified_device_name.type, " devices is available.", + debug_info); + } else { + return errors::InvalidArgument( + "Could not satisfy explicit device specification '", + node->requested_device(), debug_info); + } + } else { + // The specified device may be a valid device but the + // merged set device is different, so print both. + return errors::InvalidArgument( + "Could not satisfy explicit device specification '", + node->requested_device(), + "' because the node was colocated with a group of nodes that " + "required incompatible device '", + DeviceNameUtils::ParsedNameToString( + members_[node_root].device_name), + "'", debug_info); + } + } + } else { + // The device is completely unspecified, so enumerate the devices that + // support all of the nodes in the set. + if (device_set_->devices().empty()) { + return errors::Internal("No devices are registered"); + } + devices = FilterSupportedDevices( + device_set_->devices(), members_[node_root].supported_device_types); + + if (devices.empty()) { + return errors::InvalidArgument( + "Node had no OpKernel registered to support this operation: ", + "Operation was ", node->type_string(), " and inputs were ", + DataTypeVectorString(node->input_types()), DebugInfo(node_root)); + } + } + + // Cache the result of the possible devices for this node group. 
+ members_[node_root].possible_devices = std::move(devices); + *possible_devices = &members_[node_root].possible_devices; + return Status::OK(); + } + + Status InitializeMembers() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + Status status = InitializeMember(*node, &members_[node->id()]); + if (!status.ok()) { + return AttachDef(status, *node); + } + } + return Status::OK(); + } + + // Represents a node in the disjoint node set forest, and the + // accumulated constraints on the device used by that node. + struct Member { + Member() = default; + // The id of the node that is the parent of this one, or its own + // id if it is a root. parent <= 0 indicates that this member is invalid. + int parent = -1; + + // A proxy for the depth of the tree that is used to prefer + // connecting smaller trees to larger trees when merging disjoint + // sets. + int rank = 0; + + // The intersection of all device types supported by this node, + // and those of all of its children, in priority order + // of the preferred device. + DeviceTypeVector supported_device_types; + + // The merged form of the device requested for this node, with + // those of all of its children. + DeviceNameUtils::ParsedName device_name; + + // If this node is a root, stores a list of Devices to which this node + // and all of its children have been assigned, or nullptr if this + // has not yet been computed. + std::vector possible_devices; + }; + + // Returns debugging info for the node referred to by 'node_root'. + string DebugInfo(const int node_root) { + string text( + "\nColocation Debug Info:\n" + "Colocation group had the following types and devices: "); + + // If this node is part of a colocation group, then we want to + // collect the mapping of ops to supported devices, so that + // the user can see why an unsatisfiable placement occurred. + + std::unordered_map type_to_devices; + int num_nodes_found = 0; + + for (const Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + int id = node->id(); + if (FindRoot(id) != node_root) { + continue; + } + ++num_nodes_found; + const string& op_type = node->type_string(); + string devices_registered; + for (const auto& device_type : members_[id].supported_device_types) { + strings::StrAppend(&devices_registered, DeviceTypeString(device_type), + " "); + } + + type_to_devices[op_type] = std::move(devices_registered); + } + + for (const auto& td : type_to_devices) { + strings::StrAppend(&text, "\n", td.first, ": ", td.second); + } + + if (num_nodes_found <= 1) { + text.clear(); + } + return text; + } + + Status InitializeMember(const Node& node, Member* member) { + const int id = node.id(); + DCHECK_GE(id, 0); + member->parent = id; + TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( + device_types_, node.def(), &member->supported_device_types)); + + if (node.has_assigned_device_name()) { + // This node has already been assigned to a device, so we + // respect this placement, after sanity-checking it. The + // device_name and supported_device_types for this node reflect + // the assigned device, so any nodes colocated with this node + // will be assigned to the same device (assuming this is + // possible). + // NOTE: Since any assignment must have been performed by + // the TensorFlow runtime, we consider errors in this branch to + // be INTERNAL. 
+ const string& assigned_device_name = node.assigned_device_name(); + if (!DeviceNameUtils::ParseFullName(assigned_device_name, + &member->device_name)) { + return errors::Internal("Malformed assigned device '", + assigned_device_name, "'"); + } + const Device* assigned_device = + device_set_->FindDeviceByName(assigned_device_name); + if (assigned_device == nullptr) { + return errors::Internal("Assigned device '", assigned_device_name, + "' does not match any device"); + } + + for (const DeviceType& d : member->supported_device_types) { + if (DeviceType(assigned_device->attributes().device_type()) == d) { + return Status::OK(); + } + } + + return errors::Internal("Assigned device '", assigned_device_name, + "' does not have registered OpKernel support " + "for ", + node.type_string()); + } else { + // This node has not yet been assigned to a device, so we + // calculate any constraints due to the set of registered + // kernels and any (partial) user-provided device specification + // in the NodeDef. + + // If no kernels are registered for this op type, fail with an error. + if (member->supported_device_types.empty()) { + std::set registered_device_types; + for (Device* d : device_set_->devices()) { + registered_device_types.insert(d->device_type()); + } + return errors::InvalidArgument( + "No OpKernel was registered to support Op '", node.type_string(), + "' with these attrs. Registered devices: [", + str_util::Join(registered_device_types, ","), + "], Registered kernels:\n", + KernelsRegisteredForOp(node.type_string())); + } + + // If the NodeDef contains a device, then we interpret it as a + // (partial) device specification. + if (!node.requested_device().empty()) { + // The user has specified a device in the NodeDef, try to find a + // valid device matching their specification in the set of + // devices. + // NOTE: The full name may specify a device that is not in + // n.supported_device_types(), but we check that in AssignDevice(). + if (!DeviceNameUtils::ParseFullName(node.requested_device(), + &member->device_name)) { + return errors::InvalidArgument("Malformed device specification '", + node.requested_device(), "'"); + } + } + } + return Status::OK(); + } + + // Updates target to contain the intersection of the device types in + // "target" and "other". + static void MergeSupportedDevices(DeviceTypeVector* target, + const DeviceTypeVector& other) { + DeviceTypeVector temp = *target; + target->clear(); + + // Iterate in priority order. + for (const DeviceType& device_type : temp) { + bool found = false; + for (const DeviceType& other_device_type : other) { + if (device_type == other_device_type) { + found = true; + break; + } + } + if (found) { + target->push_back(device_type); + } + } + } + + // Returns the root node of the disjoint tree to which the node with the + // given id is connected. + int FindRoot(int node_id) { + Member& member = members_[node_id]; + + int parent = member.parent; + DCHECK_GE(parent, 0); + + if (parent != node_id) { + // NOTE: Compress paths from node_id to its root, so that future + // calls to FindRoot and ColocateNodes are more efficient. + int root = FindRoot(parent); + if (parent != root) { + parent = root; + member.parent = root; + } + } + + DCHECK_GE(parent, 0); + return parent; + } + + Graph* const graph_; // Not owned. + std::vector members_; + const DeviceSet* device_set_; // Not owned. 
+ const std::vector device_types_; + const bool allow_soft_placement_; +}; + +// Returns true if the node has no inputs and produces outputs +// that are consumed by a single node. +// +// TODO(vrv): Currently this handles only nodes with one output, but +// this could be extended to handle the case where a node has many +// outputs that are connected to nodes in the same colocation group. +bool IsGeneratorNode(const Node* node) { + return node->num_inputs() == 0 && node->num_outputs() == 1 && + !IsRefType(node->output_type(0)); +} + +} // namespace + +Placer::Placer(Graph* graph, const DeviceSet* devices, + const SessionOptions* options) + : graph_(graph), + devices_(devices), + options_(options), + log_device_placement_(options != nullptr && + options->config.log_device_placement()) {} + +Placer::Placer(Graph* graph, const DeviceSet* devices) + : Placer(graph, devices, nullptr) {} + +Placer::~Placer() {} + +Status Placer::Run() { + if (devices_->devices().empty()) { + return errors::FailedPrecondition("No devices are registered"); + } + + ColocationGraph colocation_graph( + graph_, devices_, + options_ == nullptr || options_->config.allow_soft_placement()); + + TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers()); + + // 1. First add all of the nodes. Note that steps (1) and (2) + // requires two passes over the nodes because the graph (and hence + // the constraints) may not be acyclic. + TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes()); + + // 2. Enumerate the constraint edges, and use them to update the disjoint + // node set. + + // If `node` has an input edge with reference type, add an + // edge from the source of that edge to `node`. + for (const Edge* edge : graph_->edges()) { + if (edge->IsControlEdge()) { + continue; + } + Node* src = edge->src(); + Node* dst = edge->dst(); + DataType input_type = dst->input_type(edge->dst_input()); + if (input_type == DT_RESOURCE || IsRefType(input_type)) { + int src_root_id = colocation_graph.FindRoot(src->id()); + int dst_root_id = colocation_graph.FindRoot(dst->id()); + auto& src_root = colocation_graph.members_[src_root_id]; + auto& dst_root = colocation_graph.members_[dst_root_id]; + // If both the source node and this node have partially + // specified a device, then 'node's device should be + // cleared: the reference edge forces 'node' to be on the + // same device as the source node. + const auto& source_parsed_name = src_root.device_name; + const auto& dest_parsed_name = dst_root.device_name; + if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && + DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { + // Ignore a specified device for 'dst' if the two names were + // incompatible. 
+ if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, + dest_parsed_name)) { + if (log_device_placement_) { + LOG(INFO) << "Ignoring device specification " + << DeviceNameUtils::ParsedNameToString(dest_parsed_name) + << " for node '" << dst->name() + << "' because the input edge from '" << src->name() + << "' is a reference connection and already has a device " + "field set to " + << DeviceNameUtils::ParsedNameToString( + source_parsed_name); + } + + // Make 'dst' colocated with the source + dst_root.device_name = source_parsed_name; + } else { + bool source_subset_of_dest = DeviceNameUtils::IsSpecification( + source_parsed_name, dest_parsed_name); + bool dest_subset_of_source = DeviceNameUtils::IsSpecification( + dest_parsed_name, source_parsed_name); + + if (source_subset_of_dest && !dest_subset_of_source) { + src_root.device_name = dest_parsed_name; + } else { + dst_root.device_name = source_parsed_name; + } + } + } + + Status status = + colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Nodes were connected by a " + "reference connection (requiring them to " + "be on the same device), but the two nodes " + "were assigned two different devices: ", + status.error_message()), + *dst); + } + } + } + + // 3. For each node, assign a device based on the constraints in the + // disjoint node set. + std::vector second_pass; + for (Node* node : graph_->op_nodes()) { + // The graph may have come pre-populated by the framework with assigned + // devices (e.g., for stateful placements), so the placer should not try to + // place nodes that are already placed. + if (node->has_assigned_device_name()) { + LogDeviceAssignment(node); + continue; + } + + // Heuristic A: prefer to place "generators" with their only + // consumers. + // + // If this is a node with no inputs and one output, we save + // this for a second pass, so that the consumer's placement + // is chosen. + if (IsGeneratorNode(node)) { + second_pass.push_back(node); + continue; + } + + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation '", + node->name(), "': ", status.error_message()), + *node); + } + + // Returns the first device in sorted devices list so we will always + // choose the same device. + // + // TODO(vrv): Factor this assignment out into a pluggable + // algorithm, so that Placer is responsible for enforcing + // preconditions and we can experiment with other algorithms when + // given a choice of devices. Once we have a better idea of the + // types of heuristics we want to use and the information needed + // to perform good placement we can add an interface for this. + int assigned_device = -1; + + // Heuristic B: If the node only operates on metadata, not data, + // then it is desirable to place that metadata node with its + // input. + if (IsMetadata(node)) { + // Make sure that the input device type is in the list of supported + // device types for this node. + const Node* input = (*node->in_edges().begin())->src(); + // TODO(vrv): if the input is empty, consider postponing this + // node's assignment to the second pass, so that we handle the + // case where a metadata node's input comes from a backedge + // of a loop. 
+ if (CanAssignToDevice(input->assigned_device_name(), *devices)) { + assigned_device = input->assigned_device_name_index(); + } + } + + // Provide the default, if necessary. + if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + + AssignAndLog(assigned_device, node); + } + + // 4. Perform a second pass assignment for those nodes explicitly + // skipped during the first pass. + for (Node* node : second_pass) { + std::vector* devices; + Status status = colocation_graph.GetDevicesForNode(node, &devices); + if (!status.ok()) { + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation '", + node->name(), "': ", status.error_message()), + *node); + } + + int assigned_device = -1; + + // Heuristic A application. + if (IsGeneratorNode(node)) { + const Node* output = (*node->out_edges().begin())->dst(); + int output_device_name = output->assigned_device_name_index(); + + const bool consumers_on_same_device = std::all_of( + node->out_edges().begin(), node->out_edges().end(), + [output_device_name](const Edge* e) { + return e->dst()->assigned_device_name_index() == output_device_name; + }); + + if (consumers_on_same_device && + CanAssignToDevice(output->assigned_device_name(), *devices)) { + assigned_device = output_device_name; + } + } + + // Provide the default, if necessary. + if (assigned_device == -1) { + assigned_device = graph_->InternDeviceName((*devices)[0]->name()); + } + + AssignAndLog(assigned_device, node); + } + + return Status::OK(); +} + +bool Placer::CanAssignToDevice(const string& candidate_device_name, + const std::vector& devices) const { + if (!candidate_device_name.empty()) { + // 'devices' lists the set of devices that the placer or the user has + // constrained the operation to. "candidate_device_name" must + // refer to a concrete Device that is in the list of 'devices'. + const Device* other_device = + devices_->FindDeviceByName(candidate_device_name); + if (std::find(devices.begin(), devices.end(), other_device) != + devices.end()) { + return true; + } + } + + return false; +} + +void Placer::AssignAndLog(int assigned_device, Node* node) const { + node->set_assigned_device_name_index(assigned_device); + LogDeviceAssignment(node); +} + +void Placer::LogDeviceAssignment(const Node* node) const { + // Log placement if log_device_placement is set. + if (log_device_placement_) { + printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), + node->assigned_device_name().c_str()); + LOG(INFO) << node->name() << ": " + << "(" << node->type_string() << ")" + << node->assigned_device_name(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h new file mode 100644 index 0000000000..c5b76592e1 --- /dev/null +++ b/tensorflow/core/common_runtime/placer.h @@ -0,0 +1,101 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#define TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// A placement algorithm that assigns the nodes of the given Graph to
+// devices in the given DeviceSet, respecting the following constraints:
+//
+// 1. Existing device assignments remain unchanged.
+// 2. Requested (partial or complete) device specifications given by device name
+// for each node are granted.
+// 3. Nodes connected by edges of a reference type are colocated on
+// the same device.
+// 4. Given nodes "A" and "B", if node "B" has a colocation group
+// "@loc:A", nodes "A" and "B" will be colocated on the same device.
+//
+// The implementation builds a constraint graph with the same set of
+// nodes, and edges that represent colocation constraints between
+// nodes. Each connected component in the resulting constraint graph
+// is then assigned to a set of valid devices.
+//
+// Run() will finally assign the device to each node given the list of
+// possible devices.
+//
+// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as
+// possible to node 'y' while respecting the other constraints"?
+// TODO(mrry): Create a common interface for this and the other
+// placement algorithms so that they may be injected into the graph
+// builder.
+class Placer {
+ public:
+ // A map from graph node names to numerical IDs (in a Graph object).
+ typedef std::unordered_map<string, int> NodeNameToIdMap;
+
+ // Creates an instance of the Placer algorithm for the given
+ // Graph "graph" (nodes in which may or may not be assigned) on the
+ // given DeviceSet "devices".
+ //
+ // The "graph" and "devices" pointer arguments
+ // are borrowed by this Placer, and must outlive it.
+ Placer(Graph* graph, const DeviceSet* devices, const SessionOptions* options);
+
+ Placer(Graph* graph, const DeviceSet* devices);
+
+ ~Placer();
+
+ // Assigns each node in this Placer's graph to a device in its
+ // set of devices.
+ //
+ // This method is not thread-safe.
+ // Run() may be invoked at most once.
+ Status Run();
+
+ private:
+ // Returns true if the device type of 'candidate_device_name' is
+ // found in 'devices'.
+ bool CanAssignToDevice(const string& candidate_device_name,
+ const std::vector<Device*>& devices) const;
+
+ // Assigns 'node's device to 'assigned_device', and logs the
+ // placement if the SessionOptions entry in 'options_' requests it.
+ void AssignAndLog(int assigned_device, Node* node) const;
+ void LogDeviceAssignment(const Node* node) const;
+
+ Graph* const graph_; // Not owned.
+ const DeviceSet* const devices_; // Not owned.
+ const SessionOptions* options_; // Not owned.
+ const bool log_device_placement_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Placer);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_PLACER_H_
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
new file mode 100644
index 0000000000..5d87b1e279
--- /dev/null
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -0,0 +1,1283 @@
+/* Copyright 2015 The TensorFlow Authors.
All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/placer.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Op, kernel, and device registrations to set up the environment.
+//
+// The Placer uses information about the op (input types),
+// kernel (device constraints), and available devices to make
+// placement decisions. To avoid depending on the full runtime, we
+// define dummy implementations of these, and register them with the
+// runtime.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A dummy OpKernel that is used to register ops on different devices.
+class DummyOp : public OpKernel {
+ public:
+ explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+// A fake device that has specific device attributes, used to simulate
+// the presence of a CPU or a GPU (without depending on that part of
+// the runtime).
+class FakeDevice : public Device {
+ private:
+ explicit FakeDevice(const DeviceAttributes& device_attributes)
+ : Device(nullptr, device_attributes) {}
+
+ public:
+ Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
+
+ static std::unique_ptr<Device> MakeCPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType("FakeCPU").type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+
+ static std::unique_ptr<Device> MakeGPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType("FakeGPU").type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+};
+
+class DummyFactory : public DeviceFactory {
+ public:
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ return Status::OK();
+ }
+};
+
+// Device order now depends on the registration of devices, not a fixed
+// value in device_set.cc. To avoid the need to link in the real CPU and GPU
+// devices into this test, we create fake devices and registrations that
+// can stand-in for the real devices for the purposes of testing placement
+// and ordering.
+REGISTER_LOCAL_DEVICE_FACTORY("FakeCPU", DummyFactory);
+REGISTER_LOCAL_DEVICE_FACTORY("FakeGPU", DummyFactory, 51);
+
+// Register the following ops so they can be added to a Graph, and
+// kernels so that they can be placed on particular device types.
+REGISTER_OP("TestVariable").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("VariableCPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device("FakeCPU"), DummyOp);
+
+REGISTER_OP("VariableGPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("VariableNoKernels").Output("o: Ref(float)");
+
+REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("ReluCPU").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("ReluCPU").Device("FakeCPU"), DummyOp);
+
+REGISTER_OP("ReluGPU").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device("FakeCPU"), DummyOp);
+
+REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestInput").Device("FakeCPU"), DummyOp);
+
+// Op producing an output that can be placed on CPU
or GPU. +REGISTER_OP("TestCPUGPUOutput").Output("a: float"); +REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeGPU"), DummyOp); + +REGISTER_OP("TestGPUOutput").Output("a: float"); +REGISTER_KERNEL_BUILDER(Name("TestGPUOutput").Device("FakeGPU"), DummyOp); + +REGISTER_OP("TestDevice").Output("a: float").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDevice").Device("FakeGPU"), DummyOp); + +REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp); + +REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp); + +//////////////////////////////////////////////////////////////////////////////// +// +// A PlacerTest method has three phases: +// +// 1. Build a TensorFlow graph, with no (or partial) device assignments. +// 2. Attempt to compute a placement using the Placer. +// 3. EITHER: test that the constraints implied by the graph are respected; +// or that an appropriate error was reported. +// +//////////////////////////////////////////////////////////////////////////////// +class PlacerTest : public ::testing::Test { + protected: + PlacerTest() { + // Build a set of 10 GPU and 10 CPU devices. + // NOTE: this->local_devices_ owns the device objects; + // this->devices_ contains borrowed pointers to the device + // objects. + for (int i = 0; i < 10; ++i) { + local_devices_.emplace_back(FakeDevice::MakeCPU( + strings::StrCat("/job:a/replica:0/task:0/device:fakecpu:", i))); + devices_.AddDevice(local_devices_.back().get()); + // Insert the GPUs in reverse order. + local_devices_.emplace_back(FakeDevice::MakeGPU( + strings::StrCat("/job:a/replica:0/task:0/device:fakegpu:", 9 - i))); + devices_.AddDevice(local_devices_.back().get()); + } + } + + // Builds the given graph, and (if successful) indexes the node + // names for use in placement, and later lookup. + Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { + TF_RETURN_IF_ERROR(builder.ToGraph(out_graph)); + nodes_by_name_.clear(); + for (Node* node : out_graph->nodes()) { + nodes_by_name_[node->name()] = node->id(); + } + return Status::OK(); + } + + // Invokes the Placer on "graph". If no DeviceSet is specified, the + // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices). + // + // REQUIRES: "*graph" was produced by the most recent call to BuildGraph. + Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) { + Placer placer(graph, devices, options); + return placer.Run(); + } + + Status Place(Graph* graph, DeviceSet* devices) { + return Place(graph, devices, nullptr); + } + + Status Place(Graph* graph, SessionOptions* options) { + return Place(graph, &devices_, options); + } + + Status Place(Graph* graph) { return Place(graph, &devices_, nullptr); } + + // Returns the node in "graph" with the given name. + // + // REQUIRES: "graph" was produced by the most recent call to BuildGraph. 
+ Node* GetNodeByName(const Graph& graph, const string& name) { + const auto search = nodes_by_name_.find(name); + CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name; + return graph.FindNodeId(search->second); + } + + protected: + std::vector> local_devices_; + DeviceSet devices_; + Placer::NodeNameToIdMap nodes_by_name_; + + Status ReferenceTestHelper(const string& variable_op_type, + const string& assign_op_type, + const DeviceType& expected_device_type); +}; + +#define EXPECT_COLOCATED(g, name_a, name_b) \ + do { \ + Graph& g_ = (g); \ + EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(), \ + GetNodeByName(g_, (name_b))->assigned_device_name()); \ + } while (0) + +#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \ + do { \ + Graph& g_ = (g); \ + EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \ + GetNodeByName(g_, (name_b))->assigned_device_name()); \ + } while (0) + +#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \ + EXPECT_EQ(DeviceType(expected_device_type).type(), \ + devices_ \ + .FindDeviceByName( \ + GetNodeByName((g), (name))->assigned_device_name()) \ + ->attributes() \ + .device_type()) + +#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \ + EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \ + .contains(device_substr)) + +// Test that a graph with no constraints will successfully assign nodes to the +// "best available" device (i.e. prefer GPU over CPU). +TEST_F(PlacerTest, TestNoConstraints) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1")); + ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU"); +} + +// Test that a graph with device type and reference constraints on +// some of the ops will successfully assign nodes to the constrained +// device, and colocate nodes with reference connections. +TEST_F(PlacerTest, TestDeviceTypeConstraints) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu")); + Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu")); + ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "assign_cpu", "FakeCPU"); + EXPECT_COLOCATED(g, "var_cpu", "assign_cpu"); + EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "assign_gpu", "FakeGPU"); + EXPECT_COLOCATED(g, "var_gpu", "assign_gpu"); +} + +TEST_F(PlacerTest, TestMetadataColocatedWithInput) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + + // Normally, shape has a GPU implementation and would be placed + // on GPU. However, because it is a metadata operation, it is + // placed on CPU to avoid transferring the data from CPU to GPU. + ops::UnaryOp("Shape", var_cpu, b.opts().WithName("shape_op")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "shape_op", "FakeCPU"); + EXPECT_COLOCATED(g, "var_cpu", "shape_op"); +} + +// Heuristic A implements "Island fusing": if a node only generates +// an output and it has only one consumer, we place the node +// with its consumer. +TEST_F(PlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + // A variable is only on CPU + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + + // The constant to be assigned can be on both GPU or CPU. + // + // Because of the heuristic, it gets placed on CPU to avoid a + // copy. + Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); + + // The assign is bound to CPU by the reference edge. + ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); + + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "var_cpu", "in"); + EXPECT_COLOCATED(g, "assign", "in"); +} + +TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + // A variable is only on CPU + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + + // The constant to be assigned can only be on GPU. + // + // The heuristic to place the generator with its consumer does + // not apply since the consumer's device is not in the list + // of valid devices for the generator. + Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); + + // The assign is bound to CPU by the reference edge. + ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); + + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); + EXPECT_COLOCATED(g, "var_cpu", "assign"); +} + +TEST_F(PlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + // A variable is only on CPU + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + + // The constant to be assigned can be on CPU or GPU, but is explicitly + // placed on CPU:1. + // + // The heuristic to place the generator with its consumer does + // not apply since the consumer's device is not in the list + // of valid devices for the generator. + Node* input = + ops::SourceOp("TestCPUGPUOutput", + b.opts().WithName("in").WithDevice("/device:fakecpu:1")); + + // The assign is bound to CPU by the reference edge. 
+ ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); + + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu:1"); + EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); + EXPECT_COLOCATED(g, "var_cpu", "assign"); + EXPECT_DEVICE_CONTAINS(g, "var_cpu", "/device:fakecpu:0"); +} + +// Test that a graph with partial device specifications on the ops +// will successfully +TEST_F(PlacerTest, TestPartialSpec) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a")); + ops::SourceOp("TestVariable", + b.opts().WithName("var").WithDevice("/job:a")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_CONTAINS(g, "in", "/job:a"); + EXPECT_DEVICE_TYPE(g, "var", "FakeGPU"); + EXPECT_DEVICE_CONTAINS(g, "var", "/job:a"); +} + +// Test that a node with a pre-assigned device is not relocated. +TEST_F(PlacerTest, TestAssignedDevicePreserved) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name( + "/job:a/replica:0/task:0/device:fakecpu:7"); + + TF_EXPECT_OK(Place(&g)); + EXPECT_EQ("/job:a/replica:0/task:0/device:fakecpu:7", + GetNodeByName(g, "in")->assigned_device_name()); +} + +// Test that a graph with partial device specifications for CPU-only ops +// will be relocated to CPU. +TEST_F(PlacerTest, TestPartialSpecGpuToCpu) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", + b.opts().WithName("in").WithDevice("/device:fakegpu:0")); + ops::SourceOp("TestVariable", + b.opts().WithName("var").WithDevice("/device:fakegpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.set_allow_soft_placement(true); + TF_EXPECT_OK(Place(&g, &options)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu"); + EXPECT_DEVICE_TYPE(g, "var", "FakeGPU"); + EXPECT_DEVICE_CONTAINS(g, "var", "/device:fakegpu:0"); +} + +// Test that a node with an assigned GPU device but has not registered +// OpKernel will fail. +TEST_F(PlacerTest, TestAssignedGpuDeviceToCpuDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name( + "/job:a/replica:0/task:0/device:fakegpu:0"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains( + "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' " + "does not have registered OpKernel support for TestInput")); +} + +// Test that graphs with reference connections are correctly placed. + +// Build a graph containing a Variable op of "variable_op_type" and an +// Assign op of "assign_op_type", and expect all of the ops to be +// placed on a device of type "expected_device_type". 
+Status PlacerTest::ReferenceTestHelper(const string& variable_op_type, + const string& assign_op_type, + const DeviceType& expected_device_type) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + // Build ten variable-and-assignment pairs. + for (int i = 0; i < 10; ++i) { + Node* var = ops::SourceOp(variable_op_type, + b.opts().WithName(strings::StrCat("var_", i))); + ops::BinaryOp(assign_op_type, var, input, + b.opts().WithName(strings::StrCat("assign_", i))); + } + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_RETURN_IF_ERROR(Place(&g)); + + for (int i = 0; i < 10; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("assign_", i)); + EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type); + EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type); + } + + return Status::OK(); +} + +// Test all 2^3 combinations of Variable and Assignment op types +// (unconstrained, CPU-only, and GPU-only). +TEST_F(PlacerTest, TestReferenceConnection) { + Status s; + TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU")); + TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU")); + TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", "FakeGPU")); + TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", "FakeCPU")); + TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", "FakeCPU")); + { + Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU"); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("no device type supports both of those nodes")); + } + TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU")); + { + Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU"); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("no device type supports both of those nodes")); + } + TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU")); +} + +// Handle-using dummy variable ops. +REGISTER_OP("TestHandleVariable").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeGPU"), DummyOp); + +REGISTER_OP("HandleVariableCPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device("FakeCPU"), DummyOp); + +REGISTER_OP("HandleVariableGPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device("FakeGPU"), DummyOp); + +REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeGPU"), DummyOp); + +REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp); + +REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float"); +REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp); + +// Tests all combinations of resource handles and ops using them. 
+TEST_F(PlacerTest, TestResourceHandle) { + auto handle_test = [this](const string& var_op_name, + const string& use_op_name, DeviceType device) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var")); + ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_RETURN_IF_ERROR(Place(&g)); + + EXPECT_COLOCATED(g, "var", "assign"); + EXPECT_DEVICE_TYPE(g, "var", device); + EXPECT_DEVICE_TYPE(g, "assign", device); + return Status::OK(); + }; + TF_EXPECT_OK( + handle_test("TestHandleVariable", "TestHandleAssign", "FakeGPU")); + TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignCPU", "FakeCPU")); + TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignGPU", "FakeGPU")); + TF_EXPECT_OK(handle_test("HandleVariableCPU", "TestHandleAssign", "FakeCPU")); + TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", "FakeCPU")); + TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", "FakeGPU")); + TF_EXPECT_OK(handle_test("HandleVariableGPU", "TestHandleAssign", "FakeGPU")); + EXPECT_FALSE( + handle_test("HandleVariableGPU", "HandleAssignCPU", "FakeCPU").ok()); + EXPECT_FALSE( + handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok()); +} + +// Test that an assignment of an operator to the wrong device +// is ignored when it could never be satisfied (due to reference +// edges, for example). +TEST_F(PlacerTest, TestReferenceConnectionIgnoreInfeasible) { + Status s; + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp( + "TestDevice", + b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); + Node* var = ops::SourceOp("TestVariable", + b.opts().WithName("var_0").WithDevice( + "/job:a/task:0/device:fakegpu:0")); + + // This op is specified on CPU, but in practice will be ignored, + // because the reference edges forces it on GPU. + ops::BinaryOp("TestAssign", var, input, + b.opts().WithName("assign").WithDevice( + "/job:a/task:0/device:fakecpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + s = Place(&g, &options); + TF_EXPECT_OK(s); + EXPECT_DEVICE_TYPE(g, "var_0", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "assign", "FakeGPU"); +} + +// Test that an assignment of an operator to the a more specified device +// causes the device to maintain its more specific placement. +TEST_F(PlacerTest, TestReferenceConnectionMoreSpecificDestinationSourceWins) { + Status s; + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + // Input can be on either device + Node* input = + ops::SourceOp("TestCPUGPUOutput", + b.opts().WithName("in").WithDevice("/job:a/task:0")); + + // Variable can be on either device + Node* var = ops::SourceOp( + "TestVariable", b.opts().WithName("var_0").WithDevice("/job:a/task:0")); + + // This op is specified on CPU and is more specific than the variable. + // Because the variable is less specified, the variable will be + // assigned to CPU. 
+ ops::BinaryOp("TestAssign", var, input, + b.opts().WithName("assign").WithDevice( + "/job:a/task:0/device:fakecpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + s = Place(&g, &options); + TF_EXPECT_OK(s); + EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU"); +} + +// A reference connection exists between a variable and an assign, +// where the assign has a device but the variable does not. In this +// case, the variable gets placed on the location of the assign +// operation. +TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) { + Status s; + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp( + "TestDevice", + b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); + Node* var = ops::SourceOp("TestVariable", b.opts().WithName("var_0")); + ops::BinaryOp("TestAssign", var, input, + b.opts().WithName("assign").WithDevice( + "/job:a/task:0/device:fakecpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + s = Place(&g, &options); + TF_EXPECT_OK(s); + EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU"); +} + +TEST_F(PlacerTest, TestColocationGroup) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* colocated_with_input = ops::UnaryOp( + "TestRelu", input, + b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); + + // This will not be colocated with the input because TestInput is + // only availbale on CPU and TestRelu will default to GPU. + Node* not_colocated_with_input = + ops::UnaryOp("TestRelu", input, b.opts().WithName("foo")); + CHECK(colocated_with_input); + CHECK(not_colocated_with_input); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "colocated_1"); + EXPECT_NOT_COLOCATED(g, "in", "foo"); +} + +TEST_F(PlacerTest, TestMultipleColocationGroups) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* colocated_with_input = ops::UnaryOp( + "TestRelu", input, + b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); + Node* colocated_with_input_and_other = + ops::UnaryOp("TestRelu", input, + b.opts().WithName("foo").WithAttr( + "_class", {"loc:@in", "loc:@colocated_1"})); + CHECK(colocated_with_input); + CHECK(colocated_with_input_and_other); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "colocated_1"); + EXPECT_COLOCATED(g, "in", "foo"); +} + +TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* colocated_with_input = ops::UnaryOp( + "ReluCPU", input, + b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); + Node* colocated_with_input_and_other = + ops::UnaryOp("ReluGPU", input, + b.opts().WithName("foo").WithAttr( + "_class", {"loc:@in", "loc:@colocated_1"})); + CHECK(colocated_with_input); + CHECK(colocated_with_input_and_other); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Cannot colocate nodes 'foo' and 'in' because no " + "device type supports both of those nodes and the " + "other nodes colocated with them")); +} + +TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp( + "TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"})); + ops::BinaryOp( + "TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"})); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "var1"); + EXPECT_COLOCATED(g, "in", "var2"); + EXPECT_COLOCATED(g, "var1", "assign2"); + EXPECT_COLOCATED(g, "var2", "assign1"); +} + +TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + // Var 3 is on GPU + Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp( + "TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"})); + ops::BinaryOp( + "TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"})); + // Assign to var3, but try to use a colocation group that matches + // the assign of var2. This should fail because assign2 must be on CPU + // (it has a reference edge on var2), and assign3 must be on GPU, + // hence the conflict. + ops::BinaryOp( + "TestAssign", var3, input, + b.opts().WithName("assign3").WithAttr("_class", {"loc:@var2"})); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Cannot colocate nodes 'var3' and 'assign3' because no " + "device type supports both of those nodes and the other " + "nodes colocated with them.")); +} + +TEST_F(PlacerTest, TestColocationAndReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + for (int i = 0; i < 10; ++i) { + // Declare ten variable and assignment pairs. + Node* var = ops::SourceOp("TestVariable", + b.opts().WithName(strings::StrCat("var_", i))); + ops::BinaryOp("TestAssign", var, input, + b.opts().WithName(strings::StrCat("assign_", i))); + } + for (int i = 10; i < 100; ++i) { + // Create a variable colocated with some existing variable, and + // an assignment colocated with a possibly-different variable. + Node* var = ops::SourceOp( + "TestVariable", + b.opts() + .WithName(strings::StrCat("var_", i)) + .WithAttr("_class", {strings::StrCat("loc:@var_", i % 6)})); + ops::BinaryOp( + "TestAssign", var, input, + b.opts() + .WithName(strings::StrCat("assign_", i)) + .WithAttr("_class", {strings::StrCat("loc:@assign_", i % 3)})); + } + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + for (int i = 0; i < 10; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("assign_", i)); + } + for (int i = 10; i < 100; ++i) { + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("assign_", i)); + EXPECT_COLOCATED(g, strings::StrCat("var_", i), + strings::StrCat("var_", i % 6)); + EXPECT_COLOCATED(g, strings::StrCat("assign_", i), + strings::StrCat("assign_", i % 3)); + } +} + +// Test that placement fails when no devices are registered. +TEST_F(PlacerTest, TestEmptyDeviceSet) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + DeviceSet empty; + + Status s = Place(&g, &empty); + EXPECT_TRUE( + StringPiece(s.error_message()).contains("No devices are registered")); +} + +// Test that placement fails when the requested device forces an +// indirect constraint to be violated. +TEST_F(PlacerTest, TestHeterogeneousDeviceSetFailure) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* in = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); + ops::BinaryOp("TestAssign", var, in, + b.opts().WithName("assign").WithDevice("/job:b/task:1")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + DeviceSet heterogeneous; + std::unique_ptr gpu( + FakeDevice::MakeGPU("/job:b/replica:0/task:0/device:fakegpu:0")); + heterogeneous.AddDevice(gpu.get()); + std::unique_ptr cpu( + FakeDevice::MakeCPU("/job:b/replica:0/task:1/device:fakecpu:0")); + heterogeneous.AddDevice(cpu.get()); + Status s = Place(&g, &heterogeneous); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("colocated with a group of nodes that required " + "incompatible device")); + + // The error message should contain information that indicates which + // op types have which registered device types. + EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: FakeGPU")) + << s; + EXPECT_TRUE( + StringPiece(s.error_message()).contains("TestAssign: FakeGPU FakeCPU")) + << s; +} + +// Test that placement fails when an unknown device is requested. +TEST_F(PlacerTest, TestUnknownDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo"));
+}
+
+// Test that placement fails when the combination of partial
+// constraints leads to an unknown device.
+TEST_F(PlacerTest, TestUnknownMergedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo"));
+}
+
+// Test that placement fails when the previously-assigned device for a
+// node is unknown.
+TEST_F(PlacerTest, TestUnknownAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/job:foo");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:foo' does not match any device"));
+}
+
+// Test that placement fails when an op with no registered kernels is
+// requested.
+TEST_F(PlacerTest, TestNoKernelsRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "No OpKernel was registered to support Op 'VariableNoKernels'"));
+ EXPECT_TRUE(
+ StringPiece(s.error_message()).contains("<no registered kernels>"));
+}
+
+// Test that placement fails when a kernel is registered but no known
+// device supports it.
+TEST_F(PlacerTest, TestNoDevicesRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet cpu_only;
+ std::unique_ptr<Device> cpu(
+ FakeDevice::MakeCPU("/job:a/replica:0/task:0/device:fakecpu:0"));
+ cpu_only.AddDevice(cpu.get());
+
+ Status s = Place(&g, &cpu_only);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("No OpKernel was registered to support "
+ "Op 'VariableGPU'"));
+ EXPECT_TRUE(StringPiece(s.error_message()).contains("device='FakeGPU'"));
+}
+
+// Test that placement fails when a requested device is malformed.
+TEST_F(PlacerTest, TestMalformedDeviceSpecification) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Malformed device specification '/foo:bar'")); +} + +// Test that placement fails when a previously-assigned device is malformed. +TEST_F(PlacerTest, TestMalformedAssignedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Malformed assigned device '/foo:bar'")); +} + +// Test that placement fails when a device was previously assigned to +// a node, but it does not uniquely identify a particular device. +TEST_F(PlacerTest, TestNonUniqueAssignedDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestInput", b.opts().WithName("in")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + GetNodeByName(g, "in")->set_assigned_device_name("/job:a"); + + Status s = Place(&g); + EXPECT_EQ(error::INTERNAL, s.code()); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Assigned device '/job:a' does not match any device")); +} + +// Test that ops request to be placed on non-existent devices will be relocated +// to existing device of the same type if allow_soft_placement is set. +TEST_F(PlacerTest, TestNonexistentGpuAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestDevice", + b.opts().WithName("in").WithDevice("/device:fakegpu:11")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.set_allow_soft_placement(true); + TF_EXPECT_OK(Place(&g, &options)); + EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakegpu:0"); +} + +// Test that ops request to be placed on non-existent devices will fail if +// allow_soft_placement is not set. +TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("TestDevice", + b.opts().WithName("in").WithDevice("/device:fakegpu:11")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + Status s = Place(&g, &options); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakegpu:11")); +} + +// Test that placement fails when a node requests an explicit device that is not +// supported by the registered kernels if allow_soft_placement is no set. +TEST_F(PlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("VariableGPU", + b.opts().WithName("var").WithDevice("/device:fakecpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + Status s = Place(&g, &options); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakecpu:0")); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("no supported kernel for fakecpu devices is available")); +} + +// Test that placement fails when a node requests an explicit device that is not +// supported by the registered kernels if allow_soft_placement is no set. +TEST_F(PlacerTest, TestNonExistentDevice) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("VariableGPU", + b.opts().WithName("var").WithDevice("/job:foo/replica:17")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + Status s = Place(&g, &options); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + LOG(WARNING) << s.error_message(); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("was explicitly assigned to /job:foo/replica:17 " + "but available devices")); +} + +TEST_F(PlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + ops::SourceOp("VariableGPU", + b.opts().WithName("var").WithDevice("/device:fakecpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.set_allow_soft_placement(true); + TF_EXPECT_OK(Place(&g, &options)); +} + +// Test that a graph with device type and reference constraints on +// some of the ops will successfully assign nodes to the constrained +// device, and colocate nodes with reference connections. +TEST_F(PlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + // var_gpu has ref output and runs on GPU. + // force_gpu takes var_gpu and requested CPU. + // Verify that both are placed on GPU. + Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu")); + ops::UnaryOp( + "TestDeviceEnforce", var_gpu, + b.opts().WithName("force_gpu").WithDevice("/device:fakecpu:0")); + // var_cpu has ref output and runs on CPU. + // force_cpu takes var_cpu and requested GPU. + // Verify that both are placed on CPU. + Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); + ops::UnaryOp( + "TestDeviceEnforce", var_cpu, + b.opts().WithName("force_cpu").WithDevice("/device:fakegpu:0")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + SessionOptions options; + options.config.set_allow_soft_placement(true); + TF_EXPECT_OK(Place(&g, &options)); + EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "force_gpu", "FakeGPU"); + EXPECT_COLOCATED(g, "var_gpu", "force_gpu"); + EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "force_cpu", "FakeCPU"); + EXPECT_COLOCATED(g, "var_cpu", "force_cpu"); +} + +// Test that placement fails when two nodes have a reference connection +// constraint, and each node requires a mutually incompatible device. +TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
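+ // 'var' only has a FakeGPU kernel, while 'assign' only has a FakeCPU
+ // kernel; the Ref(float) edge between them requires colocation, which is
+ // unsatisfiable, so Place() must return the error checked below.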
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("Cannot colocate nodes 'var' and 'assign'")); +} + +// Test that a generator node follows its consumers (where there are several +// consumer nodes on the same devices). +TEST_F(PlacerTest, TestGeneratorNodeFollowsConsumerNode) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + // A variable is only on CPU + Node* var1_cpu = + ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu")); + Node* var2_cpu = + ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu")); + + // The constant to be assigned can be on both GPU or CPU. + // + // Because of the heuristic, it gets placed on CPU to avoid a + // copy. + Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); + + // The assigns are bound to CPU by the reference edge. + ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1")); + ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2")); + + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "var1_cpu", "in"); + EXPECT_COLOCATED(g, "assign1", "in"); + EXPECT_COLOCATED(g, "var2_cpu", "in"); + EXPECT_COLOCATED(g, "assign2", "in"); +} + +// Test that a generator node does not follow its consumers (where there are +// several consumers on different devices). +TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + + // A variable is only on CPU + Node* var1_cpu = + ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu")); + Node* var2_cpu = + ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu")); + + // The constant to be assigned can be on both GPU or CPU. + // + // Because of the heuristic, it ought to be on the GPU (cannot be + // co-located with both consumers, so goes to the 'standard' place) + Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); + + // The assigns are bound to CPU by the reference edge. + ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1")); + ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2")); + + TF_EXPECT_OK(BuildGraph(b, &g)); + + GetNodeByName(g, "var1_cpu") + ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:1"); + + GetNodeByName(g, "var2_cpu") + ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:2"); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "assign1", "var1_cpu"); + EXPECT_COLOCATED(g, "assign2", "var2_cpu"); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc deleted file mode 100644 index c66dc568f6..0000000000 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ /dev/null @@ -1,447 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/common_runtime/simple_placer.h" -#include "tensorflow/core/framework/graph.pb_text.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/subgraph.h" -#include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/graph/validate.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/device_name_utils.h" -#include "tensorflow/core/util/util.h" - -#ifndef IS_MOBILE_PLATFORM -#include "tensorflow/core/grappler/clusters/utils.h" -#include "tensorflow/core/grappler/clusters/virtual_cluster.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" -#endif // IS_MOBILE_PLATFORM - -namespace tensorflow { - -SimpleGraphExecutionState::SimpleGraphExecutionState( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options) - : stateful_placements_(options.stateful_placements), - device_set_(options.device_set), - session_options_(options.session_options), - flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(), - graph_def->library())), - graph_(nullptr) { - // NOTE(mrry): GraphDef does not have a move constructor, so we pass - // a non-const pointer and use `Swap()` to transfer the contents - // without copying. - original_graph_def_.Swap(graph_def); - // TODO(mrry): Publish placement visualizations or handle the log - // placement option. -} - -SimpleGraphExecutionState::~SimpleGraphExecutionState() { - node_name_to_cost_id_map_.clear(); - delete graph_; -} - -/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, - std::unique_ptr* out_state) { - std::unique_ptr ret( - new SimpleGraphExecutionState(graph_def, options)); - - TF_RETURN_IF_ERROR( - AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); - // TODO(mrry): Refactor InitBaseGraph() so that we don't have to - // pass an empty BuildGraphOptions (that isn't going to be used when - // place_pruned_graph is false). 
- if (!ret->session_options_->config.graph_options().place_pruned_graph()) { - TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions())); - } - *out_state = std::move(ret); - return Status::OK(); -} - -/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph( - const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def, - const BuildGraphOptions& subgraph_options, - std::unique_ptr* out_state, - std::unique_ptr* out_client_graph) { - DCHECK(options.session_options->config.graph_options().place_pruned_graph()); - // NOTE(mrry): This makes a copy of `graph_def`, which is - // regrettable. We could make `GraphDef` objects sharable between - // execution states to optimize pruned graph execution, but since - // this case is primarily used for interactive sessions, we make the - // bet that graph construction is not performance-critical. (Note - // also that the previous version used `Extend()`, which is strictly - // more expensive than copying a `GraphDef`.) - GraphDef temp(graph_def); - std::unique_ptr ret( - new SimpleGraphExecutionState(&temp, options)); - TF_RETURN_IF_ERROR( - AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0)); - TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options)); - TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph)); - *out_state = std::move(ret); - return Status::OK(); -} - -Status SimpleGraphExecutionState::Extend( - const GraphDef& extension_def, - std::unique_ptr* out) const { - GraphDef gdef; - - // 1. Copy the function library. - TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library())); - *gdef.mutable_library() = flib_def_->ToProto(); - - // 2. Build an index of the new node names. - std::unordered_set new_names; - for (const NodeDef& node : extension_def.node()) { - new_names.insert(node.name()); - } - - // 3. Add the non-duplicates from the old graph to the new graph. - // Return an error if the same node name appears in both the - // old graph and the extension. - for (const NodeDef& node : original_graph_def_.node()) { - if (new_names.count(node.name()) == 0) { - *gdef.add_node() = node; - } else { - return errors::InvalidArgument(tensorflow::strings::Printf( - "GraphDef argument to Extend includes node '%s', which was created " - "by a previous call to Create or Extend in this session.", - node.name().c_str())); - } - } - - // 4. Merge the versions field. - int old_node_size = gdef.node_size(); - gdef.mutable_node()->MergeFrom(extension_def.node()); - TF_RETURN_IF_ERROR( - AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size)); - // Merge versions - if (gdef.has_versions()) { - if (gdef.versions().producer() != extension_def.versions().producer()) { - return errors::InvalidArgument( - "Can't extend GraphDef at version ", gdef.versions().producer(), - " with graph at version ", extension_def.versions().producer()); - } - VersionDef* versions = gdef.mutable_versions(); - versions->set_min_consumer(std::max( - versions->min_consumer(), extension_def.versions().min_consumer())); - if (extension_def.versions().bad_consumers_size()) { - // Add new bad_consumers that aren't already marked bad. - // - // Note: This implementation is quadratic time if there are many calls to - // ExtendLocked with many bad consumers. Since this is unlikely, and - // fixing it would require data structures outside of this routine, - // quadratic time it is. 
- auto* bad_consumers = versions->mutable_bad_consumers(); - const std::unordered_set existing(bad_consumers->begin(), - bad_consumers->end()); - for (const int v : extension_def.versions().bad_consumers()) { - if (existing.find(v) == existing.end()) { - bad_consumers->Add(v); - } - } - } - - } else { - gdef.mutable_versions()->CopyFrom(extension_def.versions()); - } - - // 5. Validate that the final graphdef is valid. - if (gdef.versions().producer() >= 5) { - // Validate the graph: we assume that merging two valid graphs - // should maintain graph validity. - TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_)); - } - - // 6. Add the extension. - SimpleGraphExecutionStateOptions combined_options; - combined_options.device_set = device_set_; - combined_options.session_options = session_options_; - combined_options.stateful_placements = stateful_placements_; - - // NOTE(mrry): `gdef` is no longer valid after the constructor - // executes. - std::unique_ptr new_execution_state( - new SimpleGraphExecutionState(&gdef, combined_options)); - - TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( - &new_execution_state->original_graph_def_, *flib_def_, 0)); - if (!session_options_->config.graph_options().place_pruned_graph()) { - // TODO(mrry): Refactor InitBaseGraph() so that we don't have to - // pass an empty BuildGraphOptions (that isn't going to be used - // when place_pruned_graph is false). - TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions())); - } - *out = std::move(new_execution_state); - - // TODO(mrry): This is likely to be used for non-throughput-sensitive - // interactive workloads, but in future we may want to transfer other - // parts of the placement and/or cost model. - return Status::OK(); -} - -void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) { - for (Node* n : graph->nodes()) { - if (n->op_def().is_stateful()) { - VLOG(2) << "Saving " << n->DebugString(); - stateful_placements_[n->name()] = n->assigned_device_name(); - } - } -} - -void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) { - for (Node* n : graph->nodes()) { - if (n->op_def().is_stateful()) { - auto iter = stateful_placements_.find(n->name()); - if (iter != stateful_placements_.end()) { - n->set_assigned_device_name(iter->second); - VLOG(2) << "Restored " << n->DebugString(); - } - } - } -} - -Status SimpleGraphExecutionState::InitBaseGraph( - const BuildGraphOptions& options) { - const GraphDef* graph_def = &original_graph_def_; - - std::unique_ptr new_graph(new Graph(OpRegistry::Global())); - GraphConstructorOptions opts; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get())); - for (const Node* n : new_graph->nodes()) { - VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id(); - node_name_to_cost_id_map_[n->name()] = n->cost_id(); - } - if (session_options_ && - session_options_->config.graph_options().place_pruned_graph()) { - // Rewrite the graph before placement. - rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata); - TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - new_graph.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes(), - options.use_function_convention, rewrite_metadata_.get())); - } - - // Save stateful placements before placing. 
- RestoreStatefulNodes(new_graph.get()); - - GraphOptimizationPassOptions optimization_options; - optimization_options.session_options = session_options_; - optimization_options.graph = &new_graph; - optimization_options.flib_def = flib_def_.get(); - optimization_options.device_set = device_set_; - - TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( - OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); - - SimplePlacer placer(new_graph.get(), device_set_, session_options_); - // TODO(mrry): Consider making the SimplePlacer cancelable. - TF_RETURN_IF_ERROR(placer.Run()); - - TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( - OptimizationPassRegistry::POST_PLACEMENT, optimization_options)); - - SaveStatefulNodes(new_graph.get()); - graph_ = new_graph.release(); - return Status::OK(); -} - -Status SimpleGraphExecutionState::OptimizeGraph( - const BuildGraphOptions& options, std::unique_ptr* optimized_graph) { -#ifndef IS_MOBILE_PLATFORM - if (session_options_->config.graph_options().place_pruned_graph()) { - return errors::InvalidArgument("Can't optimize a pruned graph"); - } - - const RewriterConfig& rewrite_options = - session_options_->config.graph_options().rewrite_options(); - - if (grappler::MetaOptimizerEnabled(rewrite_options)) { - // Adding this functionality in steps. The first step is to make sure - // we don't break dependencies. The second step will be to turn the - // functionality on by default. - grappler::GrapplerItem item; - item.id = "tf_graph"; - graph_->ToGraphDef(&item.graph); - - item.fetch = options.fetch_endpoints; - item.fetch.insert(item.fetch.end(), options.target_nodes.begin(), - options.target_nodes.end()); - - if (!options.feed_endpoints.empty()) { - std::unordered_set feeds; - for (const string& feed : options.feed_endpoints) { - TensorId id = ParseTensorName(feed); - if (id.second != 0) { - return errors::InvalidArgument("Unsupported feed: ", feed); - } - feeds.insert(id.first.ToString()); - } - for (const NodeDef& node : original_graph_def_.node()) { - if (feeds.find(node.name()) == feeds.end()) { - continue; - } - if (node.attr().count("dtype") == 0 || - node.attr().count("shape") == 0) { - return errors::InvalidArgument("Missing node shape or type"); - } - TensorShapeProto shape_proto(node.attr().at("shape").shape()); - // If the shape of the placeholder value is only partially known, we're - // free to use any dimension we want to feed the placeholder. We choose - // 1 to minimize the memory impact. 
Note that this only matters if an - // optimizer choose to run the graph to build its cost model, which - // doesn't happen (yet) - if (shape_proto.unknown_rank()) { - shape_proto.set_unknown_rank(false); - } - for (auto& dim : *shape_proto.mutable_dim()) { - if (dim.size() < 0) { - dim.set_size(1); - } - } - TensorShape shape(shape_proto); - DataType type = node.attr().at("dtype").type(); - Tensor fake_input(type, shape); - item.feed.emplace_back(node.name(), fake_input); - } - } - - std::unordered_map device_map; - Device* cpu_device = nullptr; - for (const auto& device : device_set_->devices()) { - device_map[device->name()] = - grappler::GetDeviceInfo(device->parsed_name()); - if (device->parsed_name().id == 0 && - StringPiece(device->parsed_name().type) == "CPU" && - device->GetAllocator(AllocatorAttributes()) != nullptr) { - cpu_device = device; - } - } - if (cpu_device == nullptr) { - return errors::Internal( - "Unable to find CPU device needed for constant folding"); - } - grappler::VirtualCluster cluster(device_map); - GraphDef new_graph; - TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer( - item, rewrite_options, cpu_device, &cluster, &new_graph)); - GraphConstructorOptions opts; - opts.allow_internal_ops = true; - optimized_graph->reset(new Graph(OpRegistry::Global())); - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get())); - // The graph conversion sets the requested device names but not the assigned - // device names. However, since at this point the graph is placed TF expects - // an assigned device name for every node. Therefore we copy the requested - // device into the assigned device field. - for (Node* node : optimized_graph->get()->nodes()) { - node->set_assigned_device_name(node->requested_device()); - } - return Status::OK(); - } else { - return errors::InvalidArgument("Meta Optimizer disabled"); - } -#else - return errors::InvalidArgument("Mobile platforms not supported"); -#endif // IS_MOBILE_PLATFORM -} - -Status SimpleGraphExecutionState::BuildGraph( - const BuildGraphOptions& options, std::unique_ptr* out) { - VLOG(1) << "BuildGraph"; - if (!graph_) { - // It is only valid to call this method directly when the original graph - // was created with the option `place_pruned_graph == false`. - return errors::Internal( - "Attempted to prune a graph that has not been fully initialized."); - } - - std::unique_ptr ng; - Status s = OptimizeGraph(options, &ng); - if (!s.ok()) { - // Simply copy the original graph if we couldn't optimize it. - ng.reset(new Graph(flib_def_.get())); - CopyGraph(*graph_, ng.get()); - } - - subgraph::RewriteGraphMetadata rewrite_metadata; - if (session_options_ == nullptr || - !session_options_->config.graph_options().place_pruned_graph()) { - // Extract the subset of the graph that needs to be run, adding feed/fetch - // ops as needed. - TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - ng.get(), options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes(), - options.use_function_convention, &rewrite_metadata)); - } else { - // This SimpleGraphExecutionState represents a graph that was - // pruned when this was constructed, so we copy the metadata from - // a member variable. 
- CHECK(rewrite_metadata_); - rewrite_metadata = *rewrite_metadata_; - } - - CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size()); - CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size()); - - // Make a fresh copy of the function library for the client graph. - std::unique_ptr flib( - new FunctionLibraryDefinition(*flib_def_)); - - // TODO(andydavis): Clarify optimization pass requirements around CostModel. - GraphOptimizationPassOptions optimization_options; - optimization_options.session_options = session_options_; - optimization_options.graph = &ng; - optimization_options.flib_def = flib.get(); - optimization_options.device_set = device_set_; - - TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( - OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); - - // Copy the extracted graph in order to make its node ids dense, - // since the local CostModel used to record its stats is sized by - // the largest node id. - std::unique_ptr dense_copy( - new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types, - rewrite_metadata.fetch_types)); - CopyGraph(*ng, &dense_copy->graph); - - // TODO(vrv): We should check invariants of the graph here. - - *out = std::move(dense_copy); - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h deleted file mode 100644 index 53eef8a07d..0000000000 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ /dev/null @@ -1,209 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/build_graph_options.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_set.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/costmodel.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -struct SessionOptions; - -namespace subgraph { -struct RewriteGraphMetadata; -} - -struct SimpleGraphExecutionStateOptions { - const DeviceSet* device_set = nullptr; - const SessionOptions* session_options = nullptr; - // A map from node name to device name, representing the unchangeable - // placement of stateful nodes. - std::unordered_map stateful_placements; -}; - -// A SimpleClientGraph is simply a sub-graph of the full graph as induced by -// BuildGraphOptions. 
-struct SimpleClientGraph { - explicit SimpleClientGraph(std::unique_ptr flib, - DataTypeVector feed_types, - DataTypeVector fetch_types) - : flib_def(std::move(flib)), - graph(flib_def.get()), - feed_types(std::move(feed_types)), - fetch_types(std::move(fetch_types)) {} - // Each client-graph gets its own function library since optimization passes - // post rewrite for execution might want to introduce new functions. - std::unique_ptr flib_def; - Graph graph; - DataTypeVector feed_types; - DataTypeVector fetch_types; -}; - -// SimpleGraphExecutionState is responsible for generating an -// executable SimpleClientGraph from the original GraphDef that specifies -// the complete graph and from BuildGraphOptions which specifies -// input/output nodes. -// -// An executable Graph differs from a GraphDef by being Placed, -// meaning that each Node is assigned to a single Device in the -// available set. -// -// When SimpleGraphExecutionState is first constructed it instantiates -// a full Graph from the provided GraphDef, and places it, using only -// the static device assignments from the GraphDef. Nodes without are -// currently placed in a very naive way. Since stateful Nodes cannot -// be moved after initial placement, it is important that stateful -// Nodes get sensible initial device assignments in the graph -// definition. -// -// Subsequently, SimpleGraphExecutionState generates a SimpleClientGraph on -// demand, which is a sub-graph of the latest placement of the full -// Graph. MasterSession uses such a SimpleClientGraph to execute one or -// more similar client requests. -// -// SimpleGraphExecutionState is thread-safe. - -class SimpleGraphExecutionState { - public: - virtual ~SimpleGraphExecutionState(); - - // Creates a new `SimpleGraphExecutionState` for the given - // `graph_def`, which represents the entire graph for a session. - // - // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def` - // in an undefined state. If it is necessary to use `*graph_def` - // after this call, make an explicit copy of the graph before - // calling this method. - static Status MakeForBaseGraph( - GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, - std::unique_ptr* out_state); - - // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph` - // for the subgraph of `original_graph_def` defined by - // `subgraph_options`. - static Status MakeForPrunedGraph( - const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options, - const GraphDef& original_graph_def, - const BuildGraphOptions& subgraph_options, - std::unique_ptr* out_state, - std::unique_ptr* out_client_graph); - - // Creates a new SimpleGraphExecutionState representing the - // concatenation of this graph, and the graph defined by - // "extension_def". The same name may not be used to define a node - // in both this graph and "extension_def". - // - // If successful, returns OK and the caller takes ownership of "*out". - // Otherwise returns an error and does not modify "*out". - // - // After calling `old_state->Extend()`, `old_state` may no longer be - // used. - // - // NOTE(mrry): This method respects the placement of stateful nodes in - // in *this, but currently does not transfer any other placement - // or cost model information to the new graph. - Status Extend(const GraphDef& extension_def, - std::unique_ptr* out) const; - - // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by - // the Node set specified in "options"). 
If successful, returns OK - // and the caller takes the ownership of "*out". Otherwise, returns - // an error. - Status BuildGraph(const BuildGraphOptions& options, - std::unique_ptr* out); - - // The graph returned by BuildGraph may contain only the pruned - // graph, whereas some clients may want access to the full graph. - const Graph* full_graph() { - return graph_; - } - - // Returns the node with the given name, or null if it does not exist. - const Node* get_node_by_name(const string& name) const { - NodeNameToCostIdMap::const_iterator iter = - node_name_to_cost_id_map_.find(name); - if (iter != node_name_to_cost_id_map_.end()) { - return graph_->FindNodeId(iter->second); - } else { - return nullptr; - } - } - - // Returns a reference to the current graph_def. Use must - // not extend beyond lifetime of SimpleGrahExecutionState object. - const GraphDef& original_graph_def() { return original_graph_def_; } - - // Returns the map of stateful placements as a map of - // node name to placement string. - std::unordered_map GetStatefulPlacements() const { - return stateful_placements_; - } - - private: - SimpleGraphExecutionState(GraphDef* graph_def, - const SimpleGraphExecutionStateOptions& options); - - Status InitBaseGraph(const BuildGraphOptions& options); - - // Map of placed stateful nodes, i.e. nodes for which is_stateful() - // is true, such as "params" and "queue" nodes. Once placed these - // nodes can not be moved to a different device. Maps node names to - // device names. - std::unordered_map stateful_placements_; // Immutable after - // ctor. - void SaveStatefulNodes(Graph* graph); - void RestoreStatefulNodes(Graph* graph); - - Status OptimizeGraph(const BuildGraphOptions& options, - std::unique_ptr* optimized_graph); - - GraphDef original_graph_def_; // Immutable after ctor. - const DeviceSet* device_set_; // Not owned - const SessionOptions* session_options_; // Not owned - - // Map from name to Node for the full graph in placed_. - NodeNameToCostIdMap node_name_to_cost_id_map_; - - // 'flib_def_' is initialized from the initial graph def's library, - // and may be updated by a graph optimization pass. - std::unique_ptr flib_def_; - - // `rewrite_metadata_` is only set for SimpleGraphExecutionState - // objects created by `MakeForPrunedGraph()`. - std::unique_ptr rewrite_metadata_; - - // The dataflow graph owned by this object. - Graph* graph_; - - TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_ diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc deleted file mode 100644 index 663e62a765..0000000000 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ /dev/null @@ -1,881 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/common_runtime/simple_placer.h" - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace tensorflow { - -namespace { - -// We hoist the conversion from C-style string literal to StringPiece here, -// so that we can avoid the many repeated calls to strlen(). -const StringPiece kColocationAttrNameStringPiece(kColocationAttrName); -const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); - -// Returns a list of devices sorted by preferred type and then name -// from 'devices' whose type is in 'supported_device_types'. This -// function searches the device types in 'supported_device_types' and -// returns the subset of devices that match. -std::vector FilterSupportedDevices( - const std::vector& devices, - const DeviceTypeVector& supported_device_types) { - std::vector filtered_devices; - for (const DeviceType& d : supported_device_types) { - for (Device* device : devices) { - if (DeviceType(device->attributes().device_type()) == d) { - filtered_devices.emplace_back(device); - } - } - } - - auto device_sort = [](const Device* a, const Device* b) { - auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type())); - auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type())); - // First sort by prioritized device type (higher is preferred) and - // then by device name (lexicographically). - if (a_priority != b_priority) { - return a_priority > b_priority; - } - return StringPiece(a->name()) < StringPiece(b->name()); - }; - std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort); - return filtered_devices; -} - -// This class maintains the connected components of a colocation -// constraint graph, and uses this information to assign a satisfying -// device placement to the nodes of the graph. -// -// The typical usage pattern is: -// -// Graph graph = ...; -// DeviceSet device_set = ...; -// ColocationGraph colocation_graph(graph, device_set); -// -// // Add all the nodes of graph to colocation_graph. -// for (Node* node : graph.nodes()) { -// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node)); -// } -// -// // Add one or more colocation constraint. -// Node node_1 = *graph.FindNodeId(...); -// Node node_2 = *graph.FindNodeId(...); -// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2)); -// -// // Assign devices based on the accumulated constraints. -// for (Node* node : graph.nodes()) { -// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node)); -// } -// -// The implementation uses the union-find algorithm to maintain the -// connected components efficiently and incrementally as edges -// (implied by ColocationGraph::ColocateNodes() invocations) are added. 
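// For illustration only (not part of this change): a minimal, self-contained
// sketch of the union-by-rank / path-compression scheme that the comment above
// describes. The names below (UfMember, MakeSingletons, FindRoot, Union) are
// placeholders invented for this sketch, not TensorFlow APIs; the real logic
// lives in ColocationGraph::FindRoot() and ColocateNodes() further down in
// this file.
//
// #include <utility>
// #include <vector>
//
// // Each entry is one node; a set's root has parent == its own index.
// struct UfMember {
//   int parent = 0;
//   int rank = 0;
// };
//
// // Creates n singleton sets, one per node id.
// std::vector<UfMember> MakeSingletons(int n) {
//   std::vector<UfMember> members(n);
//   for (int i = 0; i < n; ++i) members[i].parent = i;
//   return members;
// }
//
// // Follows parent links to the root, compressing the path so that repeated
// // lookups on the same chain become effectively constant time.
// int FindRoot(std::vector<UfMember>& members, int id) {
//   if (members[id].parent != id) {
//     members[id].parent = FindRoot(members, members[id].parent);
//   }
//   return members[id].parent;
// }
//
// // Merges the sets containing x and y: the shallower tree (lower rank) is
// // attached beneath the deeper one, which keeps the forest shallow even
// // after long chains of merges.
// void Union(std::vector<UfMember>& members, int x, int y) {
//   int x_root = FindRoot(members, x);
//   int y_root = FindRoot(members, y);
//   if (x_root == y_root) return;
//   if (members[x_root].rank < members[y_root].rank) std::swap(x_root, y_root);
//   members[y_root].parent = x_root;
//   if (members[x_root].rank == members[y_root].rank) ++members[x_root].rank;
// }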
-class ColocationGraph { - public: - ColocationGraph(Graph* graph, const DeviceSet* device_set, - bool allow_soft_placement) - : graph_(graph), - device_set_(device_set), - device_types_(device_set->PrioritizedDeviceTypeList()), - allow_soft_placement_(allow_soft_placement) { - members_.resize(graph->num_node_ids()); - } - - // Adds each node of the Graph to this ColocationGraph as a singleton. - // - // NOTE: The implementation assumes that the ids of nodes passed to - // this method are dense and zero-based; the memory used will be linear in - // the largest node ID. - // NOTE: If this method returns an error, *this is left in an undefined - // state. - Status ColocateAllNodes() { - // This maps from a colocation group identifier to the 'root' of that - // colocation group. Note that the keys in this map are StringPiece; the - // actual strings are stored under the NodeDef. The lifetime of this map - // is limited to this ColocateAllNodes() method, and no part of the - // NodeDef trees are changed during the lifetime of this method, so using - // StringPiece as a key is safe. - // - // Also, as a further optimization, we remove the "loc:@" prefix from - // "class" attribute values, when they are used as keys in this table. - // This allows us to use StringPiece values that refer to substrings of - // 'string' values stored in NodeDef attribute lists, as well as StringPiece - // values that refer to 'string' values from NodeDef::name(), without - // performing any string allocations. - std::unordered_map - colocation_group_root; - - for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } - - // When adding the node, identify whether it is part of a - // colocation group. - - // This code is effectively the equivalent of GetNodeAttr() for a string - // array, but it avoids all internal allocations (the allocation of the - // backing store of the std::vector as well as the copies of the - // strings within it). Instead, we combine the query of the colocation - // attribute with the calls to ColocateNodeToGroup. - bool found_spec = false; - const AttrValue* attr_value = - node->attrs().Find(kColocationAttrNameStringPiece); - if (attr_value != nullptr && attr_value->has_list()) { - for (const string& class_spec : attr_value->list().s()) { - StringPiece spec(class_spec); - if (spec.Consume(kColocationGroupPrefixStringPiece)) { - found_spec = true; - TF_RETURN_IF_ERROR( - ColocateNodeToGroup(&colocation_group_root, node, spec)); - } - } - } - - if (!found_spec) { - // If the node does not specify a colocation group, then use the - // name of this node as the colocation group. - TF_RETURN_IF_ERROR( - ColocateNodeToGroup(&colocation_group_root, node, node->name())); - } - } - - return Status::OK(); - } - - Status ColocateNodeToGroup( - std::unordered_map* - colocation_group_root, - Node* node, StringPiece colocation_group) { - const Node*& root_node = (*colocation_group_root)[colocation_group]; - if (root_node == nullptr) { - // This is the first node of the colocation group, so - // designate this node as the 'root' of that colocation group. - root_node = node; - } else { - // Try to colocate the node with the root. If there is an - // error, return it. - Status s = ColocateNodes(*node, *root_node); - if (!s.ok()) { - return AttachDef(s, *node); - } - } - return Status::OK(); - } - - // Merge the (possibly disjoint) sets containing nodes "x" and - // "y". Returns OK if the all nodes in the union of these sets can - // be placed on the same device type. 
- // - // NOTE: If this method returns an error, *this is left in an undefined - // state. - Status ColocateNodes(const Node& x, const Node& y) { - int x_root = FindRoot(x.id()); - int y_root = FindRoot(y.id()); - return ColocateNodes(x, x_root, y, y_root); - } - - // This overload of ColocateNodes() allows a caller to provide the root node - // ids for the two nodes. For large graphs, this noticeably reduces the - // graph load time. - Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root) { - if (x_root == y_root) { - return Status::OK(); - } - - DCHECK_EQ(x_root, FindRoot(x.id())); - DCHECK_EQ(y_root, FindRoot(y.id())); - - Member& x_root_member = members_[x_root]; - Member& y_root_member = members_[y_root]; - - // Merge the sets by swinging the parent pointer of the smaller - // tree to point to the root of the larger tree. Together with - // path compression in ColocationGraph::FindRoot, this ensures - // that we do not experience pathological performance on graphs - // such as chains. - int new_root, old_root; - if (x_root_member.rank < y_root_member.rank) { - // The tree rooted at x_root is shallower, so connect it to - // y_root. The rank of y_root is unchanged because its new - // child has strictly less rank. - x_root_member.parent = y_root; - new_root = y_root; - old_root = x_root; - } else if (x_root_member.rank > y_root_member.rank) { - // The tree rooted at y_root is shallower, so connect it to - // x_root. The rank of x_root is unchanged because its new - // child has strictly less rank. - y_root_member.parent = x_root; - new_root = x_root; - old_root = y_root; - } else { - // Both trees have the same rank, so break the tie by choosing - // x_root as the new root. - y_root_member.parent = x_root; - // Increment the rank of the tree rooted at x_root, because it - // is now strictly deeper than before. - ++x_root_member.rank; - new_root = x_root; - old_root = y_root; - } - - Member& new_root_member = members_[new_root]; - Member& old_root_member = members_[old_root]; - - // Merge the partial device specifications, and ensure that they are - // compatible. NULL options_ is treated as allowing soft placement. - // TODO(mrry): Consider enriching the error message by pointing - // out which nodes have the explicit partial device - // specifications that caused this conflict. - Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name, - old_root_member.device_name, - allow_soft_placement_); - if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); - } - - // Ensure that the common root has at least one supported device - // type, by computing the intersection of - // new_root_member.supported_device_types and - // old_root_member.supported_device_types. - MergeSupportedDevices(&new_root_member.supported_device_types, - old_root_member.supported_device_types); - if (new_root_member.supported_device_types.empty()) { - return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " - "other nodes colocated with them.", - DebugInfo(x_root), DebugInfo(y_root)); - } - - return Status::OK(); - } - - // For the given node, subject to the constraints previously given - // to this ColocationGraph, set its assigned_device_name. Returns OK - // if a satisfying device can be found, otherwise an error. - // - // Note: This method returns a pointer to a field within members_. 
- // The caller must not use the returned pointer after there is any possibility - // that the members_[i].possible_devices field has been modified. - Status GetDevicesForNode(Node* node, - std::vector** possible_devices) { - *possible_devices = nullptr; - const int node_root = FindRoot(node->id()); - if (!members_[node_root].possible_devices.empty()) { - *possible_devices = &members_[node_root].possible_devices; - return Status::OK(); - } - - // We have not yet computed the possible devices for the - // colocated node set containing 'node', so we do so now using the - // constraints on the root node. - - // "devices" will contain the set of feasible placements for the - // colocated node set containing 'node'. - std::vector devices; - if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) { - // The root node has a (possibly partial) device - // specification, so enumerate the physical devices that - // conform to it. - device_set_->FindMatchingDevices(members_[node_root].device_name, - &devices); - - if (!devices.empty()) { - // Filter devices into those that are compatible with the root - // node (and its children). - devices = FilterSupportedDevices( - devices, members_[node_root].supported_device_types); - } - - // Perform soft placement if allow_soft_placement_ is set. - if (devices.empty() && allow_soft_placement_) { - // The soft_device_name is the same as the node's device name - // without specifying the device type or ID. - DeviceNameUtils::ParsedName soft_device_name = - members_[node_root].device_name; - soft_device_name.type.clear(); - soft_device_name.has_type = false; - soft_device_name.has_id = false; - device_set_->FindMatchingDevices(soft_device_name, &devices); - if (!devices.empty()) { - devices = FilterSupportedDevices( - devices, members_[node_root].supported_device_types); - } - } - - if (devices.empty()) { - // Return an error when a physical device that matches an explicit - // device specification is not found. This ensures that we don't - // assign a node to GPU when the user wanted to force it on CPU. - string debug_info = DebugInfo(node_root); - - DeviceNameUtils::ParsedName specified_device_name; - if (DeviceNameUtils::ParseFullName(node->requested_device(), - &specified_device_name) && - specified_device_name == members_[node_root].device_name) { - // The specified device and merged set device match, and - // will appear in the GraphDef (for debugging), so just - // print the specified device. - std::vector devices_matching_nodedef; - device_set_->FindMatchingDevices(specified_device_name, - &devices_matching_nodedef); - if (devices_matching_nodedef.empty()) { - // Sometimes it is almost impossible to understand the problem - // without a list of available devices. - std::vector device_names; - for (const Device* device : device_set_->devices()) { - device_names.push_back(device->name()); - } - std::sort(device_names.begin(), device_names.end()); - - return errors::InvalidArgument( - "Operation was explicitly assigned to ", - node->requested_device(), " but available devices are [ ", - str_util::Join(device_names, ", "), " ]. 
Make sure ", - "the device specification refers to a valid device."); - } else if (specified_device_name.has_type) { - return errors::InvalidArgument( - "Could not satisfy explicit device specification '", - node->requested_device(), "' because no supported kernel for ", - specified_device_name.type, " devices is available.", - debug_info); - } else { - return errors::InvalidArgument( - "Could not satisfy explicit device specification '", - node->requested_device(), debug_info); - } - } else { - // The specified device may be a valid device but the - // merged set device is different, so print both. - return errors::InvalidArgument( - "Could not satisfy explicit device specification '", - node->requested_device(), - "' because the node was colocated with a group of nodes that " - "required incompatible device '", - DeviceNameUtils::ParsedNameToString( - members_[node_root].device_name), - "'", debug_info); - } - } - } else { - // The device is completely unspecified, so enumerate the devices that - // support all of the nodes in the set. - if (device_set_->devices().empty()) { - return errors::Internal("No devices are registered"); - } - devices = FilterSupportedDevices( - device_set_->devices(), members_[node_root].supported_device_types); - - if (devices.empty()) { - return errors::InvalidArgument( - "Node had no OpKernel registered to support this operation: ", - "Operation was ", node->type_string(), " and inputs were ", - DataTypeVectorString(node->input_types()), DebugInfo(node_root)); - } - } - - // Cache the result of the possible devices for this node group. - members_[node_root].possible_devices = std::move(devices); - *possible_devices = &members_[node_root].possible_devices; - return Status::OK(); - } - - Status InitializeMembers() { - for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } - Status status = InitializeMember(*node, &members_[node->id()]); - if (!status.ok()) { - return AttachDef(status, *node); - } - } - return Status::OK(); - } - - // Represents a node in the disjoint node set forest, and the - // accumulated constraints on the device used by that node. - struct Member { - Member() = default; - // The id of the node that is the parent of this one, or its own - // id if it is a root. parent <= 0 indicates that this member is invalid. - int parent = -1; - - // A proxy for the depth of the tree that is used to prefer - // connecting smaller trees to larger trees when merging disjoint - // sets. - int rank = 0; - - // The intersection of all device types supported by this node, - // and those of all of its children, in priority order - // of the preferred device. - DeviceTypeVector supported_device_types; - - // The merged form of the device requested for this node, with - // those of all of its children. - DeviceNameUtils::ParsedName device_name; - - // If this node is a root, stores a list of Devices to which this node - // and all of its children have been assigned, or nullptr if this - // has not yet been computed. - std::vector possible_devices; - }; - - // Returns debugging info for the node referred to by 'node_root'. - string DebugInfo(const int node_root) { - string text( - "\nColocation Debug Info:\n" - "Colocation group had the following types and devices: "); - - // If this node is part of a colocation group, then we want to - // collect the mapping of ops to supported devices, so that - // the user can see why an unsatisfiable placement occurred. 
- - std::unordered_map type_to_devices; - int num_nodes_found = 0; - - for (const Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } - int id = node->id(); - if (FindRoot(id) != node_root) { - continue; - } - ++num_nodes_found; - const string& op_type = node->type_string(); - string devices_registered; - for (const auto& device_type : members_[id].supported_device_types) { - strings::StrAppend(&devices_registered, DeviceTypeString(device_type), - " "); - } - - type_to_devices[op_type] = std::move(devices_registered); - } - - for (const auto& td : type_to_devices) { - strings::StrAppend(&text, "\n", td.first, ": ", td.second); - } - - if (num_nodes_found <= 1) { - text.clear(); - } - return text; - } - - Status InitializeMember(const Node& node, Member* member) { - const int id = node.id(); - DCHECK_GE(id, 0); - member->parent = id; - TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( - device_types_, node.def(), &member->supported_device_types)); - - if (node.has_assigned_device_name()) { - // This node has already been assigned to a device, so we - // respect this placement, after sanity-checking it. The - // device_name and supported_device_types for this node reflect - // the assigned device, so any nodes colocated with this node - // will be assigned to the same device (assuming this is - // possible). - // NOTE: Since any assignment must have been performed by - // the TensorFlow runtime, we consider errors in this branch to - // be INTERNAL. - const string& assigned_device_name = node.assigned_device_name(); - if (!DeviceNameUtils::ParseFullName(assigned_device_name, - &member->device_name)) { - return errors::Internal("Malformed assigned device '", - assigned_device_name, "'"); - } - const Device* assigned_device = - device_set_->FindDeviceByName(assigned_device_name); - if (assigned_device == nullptr) { - return errors::Internal("Assigned device '", assigned_device_name, - "' does not match any device"); - } - - for (const DeviceType& d : member->supported_device_types) { - if (DeviceType(assigned_device->attributes().device_type()) == d) { - return Status::OK(); - } - } - - return errors::Internal("Assigned device '", assigned_device_name, - "' does not have registered OpKernel support " - "for ", - node.type_string()); - } else { - // This node has not yet been assigned to a device, so we - // calculate any constraints due to the set of registered - // kernels and any (partial) user-provided device specification - // in the NodeDef. - - // If no kernels are registered for this op type, fail with an error. - if (member->supported_device_types.empty()) { - std::set registered_device_types; - for (Device* d : device_set_->devices()) { - registered_device_types.insert(d->device_type()); - } - return errors::InvalidArgument( - "No OpKernel was registered to support Op '", node.type_string(), - "' with these attrs. Registered devices: [", - str_util::Join(registered_device_types, ","), - "], Registered kernels:\n", - KernelsRegisteredForOp(node.type_string())); - } - - // If the NodeDef contains a device, then we interpret it as a - // (partial) device specification. - if (!node.requested_device().empty()) { - // The user has specified a device in the NodeDef, try to find a - // valid device matching their specification in the set of - // devices. - // NOTE: The full name may specify a device that is not in - // n.supported_device_types(), but we check that in AssignDevice(). 
- if (!DeviceNameUtils::ParseFullName(node.requested_device(), - &member->device_name)) { - return errors::InvalidArgument("Malformed device specification '", - node.requested_device(), "'"); - } - } - } - return Status::OK(); - } - - // Updates target to contain the intersection of the device types in - // "target" and "other". - static void MergeSupportedDevices(DeviceTypeVector* target, - const DeviceTypeVector& other) { - DeviceTypeVector temp = *target; - target->clear(); - - // Iterate in priority order. - for (const DeviceType& device_type : temp) { - bool found = false; - for (const DeviceType& other_device_type : other) { - if (device_type == other_device_type) { - found = true; - break; - } - } - if (found) { - target->push_back(device_type); - } - } - } - - // Returns the root node of the disjoint tree to which the node with the - // given id is connected. - int FindRoot(int node_id) { - Member& member = members_[node_id]; - - int parent = member.parent; - DCHECK_GE(parent, 0); - - if (parent != node_id) { - // NOTE: Compress paths from node_id to its root, so that future - // calls to FindRoot and ColocateNodes are more efficient. - int root = FindRoot(parent); - if (parent != root) { - parent = root; - member.parent = root; - } - } - - DCHECK_GE(parent, 0); - return parent; - } - - Graph* const graph_; // Not owned. - std::vector members_; - const DeviceSet* device_set_; // Not owned. - const std::vector device_types_; - const bool allow_soft_placement_; -}; - -// Returns true if the node has no inputs and produces outputs -// that are consumed by a single node. -// -// TODO(vrv): Currently this handles only nodes with one output, but -// this could be extended to handle the case where a node has many -// outputs that are connected to nodes in the same colocation group. -bool IsGeneratorNode(const Node* node) { - return node->num_inputs() == 0 && node->num_outputs() == 1 && - !IsRefType(node->output_type(0)); -} - -} // namespace - -SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices, - const SessionOptions* options) - : graph_(graph), - devices_(devices), - options_(options), - log_device_placement_(options != nullptr && - options->config.log_device_placement()) {} - -SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices) - : SimplePlacer(graph, devices, nullptr) {} - -SimplePlacer::~SimplePlacer() {} - -Status SimplePlacer::Run() { - if (devices_->devices().empty()) { - return errors::FailedPrecondition("No devices are registered"); - } - - ColocationGraph colocation_graph( - graph_, devices_, - options_ == nullptr || options_->config.allow_soft_placement()); - - TF_RETURN_IF_ERROR(colocation_graph.InitializeMembers()); - - // 1. First add all of the nodes. Note that steps (1) and (2) - // requires two passes over the nodes because the graph (and hence - // the constraints) may not be acyclic. - TF_RETURN_IF_ERROR(colocation_graph.ColocateAllNodes()); - - // 2. Enumerate the constraint edges, and use them to update the disjoint - // node set. - - // If `node` has an input edge with reference type, add an - // edge from the source of that edge to `node`. 
- for (const Edge* edge : graph_->edges()) { - if (edge->IsControlEdge()) { - continue; - } - Node* src = edge->src(); - Node* dst = edge->dst(); - DataType input_type = dst->input_type(edge->dst_input()); - if (input_type == DT_RESOURCE || IsRefType(input_type)) { - int src_root_id = colocation_graph.FindRoot(src->id()); - int dst_root_id = colocation_graph.FindRoot(dst->id()); - auto& src_root = colocation_graph.members_[src_root_id]; - auto& dst_root = colocation_graph.members_[dst_root_id]; - // If both the source node and this node have partially - // specified a device, then 'node's device should be - // cleared: the reference edge forces 'node' to be on the - // same device as the source node. - const auto& source_parsed_name = src_root.device_name; - const auto& dest_parsed_name = dst_root.device_name; - if (DeviceNameUtils::HasSomeDetails(source_parsed_name) && - DeviceNameUtils::HasSomeDetails(dest_parsed_name)) { - // Ignore a specified device for 'dst' if the two names were - // incompatible. - if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name, - dest_parsed_name)) { - if (log_device_placement_) { - LOG(INFO) << "Ignoring device specification " - << DeviceNameUtils::ParsedNameToString(dest_parsed_name) - << " for node '" << dst->name() - << "' because the input edge from '" << src->name() - << "' is a reference connection and already has a device " - "field set to " - << DeviceNameUtils::ParsedNameToString( - source_parsed_name); - } - - // Make 'dst' colocated with the source - dst_root.device_name = source_parsed_name; - } else { - bool source_subset_of_dest = DeviceNameUtils::IsSpecification( - source_parsed_name, dest_parsed_name); - bool dest_subset_of_source = DeviceNameUtils::IsSpecification( - dest_parsed_name, source_parsed_name); - - if (source_subset_of_dest && !dest_subset_of_source) { - src_root.device_name = dest_parsed_name; - } else { - dst_root.device_name = source_parsed_name; - } - } - } - - Status status = - colocation_graph.ColocateNodes(*src, src_root_id, *dst, dst_root_id); - if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Nodes were connected by a " - "reference connection (requiring them to " - "be on the same device), but the two nodes " - "were assigned two different devices: ", - status.error_message()), - *dst); - } - } - } - - // 3. For each node, assign a device based on the constraints in the - // disjoint node set. - std::vector second_pass; - for (Node* node : graph_->op_nodes()) { - // The graph may have come pre-populated by the framework with assigned - // devices (e.g., for stateful placements), so the placer should not try to - // place nodes that are already placed. - if (node->has_assigned_device_name()) { - LogDeviceAssignment(node); - continue; - } - - // Heuristic A: prefer to place "generators" with their only - // consumers. - // - // If this is a node with no inputs and one output, we save - // this for a second pass, so that the consumer's placement - // is chosen. - if (IsGeneratorNode(node)) { - second_pass.push_back(node); - continue; - } - - std::vector* devices; - Status status = colocation_graph.GetDevicesForNode(node, &devices); - if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); - } - - // Returns the first device in sorted devices list so we will always - // choose the same device. 
- // - // TODO(vrv): Factor this assignment out into a pluggable - // algorithm, so that SimplePlacer is responsible for enforcing - // preconditions and we can experiment with other algorithms when - // given a choice of devices. Once we have a better idea of the - // types of heuristics we want to use and the information needed - // to perform good placement we can add an interface for this. - int assigned_device = -1; - - // Heuristic B: If the node only operates on metadata, not data, - // then it is desirable to place that metadata node with its - // input. - if (IsMetadata(node)) { - // Make sure that the input device type is in the list of supported - // device types for this node. - const Node* input = (*node->in_edges().begin())->src(); - // TODO(vrv): if the input is empty, consider postponing this - // node's assignment to the second pass, so that we handle the - // case where a metadata node's input comes from a backedge - // of a loop. - if (CanAssignToDevice(input->assigned_device_name(), *devices)) { - assigned_device = input->assigned_device_name_index(); - } - } - - // Provide the default, if necessary. - if (assigned_device == -1) { - assigned_device = graph_->InternDeviceName((*devices)[0]->name()); - } - - AssignAndLog(assigned_device, node); - } - - // 4. Perform a second pass assignment for those nodes explicitly - // skipped during the first pass. - for (Node* node : second_pass) { - std::vector* devices; - Status status = colocation_graph.GetDevicesForNode(node, &devices); - if (!status.ok()) { - return AttachDef( - errors::InvalidArgument("Cannot assign a device for operation '", - node->name(), "': ", status.error_message()), - *node); - } - - int assigned_device = -1; - - // Heuristic A application. - if (IsGeneratorNode(node)) { - const Node* output = (*node->out_edges().begin())->dst(); - int output_device_name = output->assigned_device_name_index(); - - const bool consumers_on_same_device = std::all_of( - node->out_edges().begin(), node->out_edges().end(), - [output_device_name](const Edge* e) { - return e->dst()->assigned_device_name_index() == output_device_name; - }); - - if (consumers_on_same_device && - CanAssignToDevice(output->assigned_device_name(), *devices)) { - assigned_device = output_device_name; - } - } - - // Provide the default, if necessary. - if (assigned_device == -1) { - assigned_device = graph_->InternDeviceName((*devices)[0]->name()); - } - - AssignAndLog(assigned_device, node); - } - - return Status::OK(); -} - -bool SimplePlacer::CanAssignToDevice( - const string& candidate_device_name, - const std::vector& devices) const { - if (!candidate_device_name.empty()) { - // 'devices' lists the set of devices that the placer or the user has - // constrained the operation to. "candidate_device_name" must - // refer to a concrete Device that is in the list of 'devices'. - const Device* other_device = - devices_->FindDeviceByName(candidate_device_name); - if (std::find(devices.begin(), devices.end(), other_device) != - devices.end()) { - return true; - } - } - - return false; -} - -void SimplePlacer::AssignAndLog(int assigned_device, Node* node) const { - node->set_assigned_device_name_index(assigned_device); - LogDeviceAssignment(node); -} - -void SimplePlacer::LogDeviceAssignment(const Node* node) const { - // Log placement if log_device_placement is set. 
- if (log_device_placement_) { - printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(), - node->assigned_device_name().c_str()); - LOG(INFO) << node->name() << ": " - << "(" << node->type_string() << ")" - << node->assigned_device_name(); - } -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h deleted file mode 100644 index 9c63cef40b..0000000000 --- a/tensorflow/core/common_runtime/simple_placer.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ -#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ - -#include -#include - -#include "tensorflow/core/common_runtime/device_set.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/util/device_name_utils.h" - -namespace tensorflow { - -// A placement algorithm that assigns the nodes of the given Graph to -// devices the given DeviceSet, respecting the following constraints: -// -// 1. Existing device assignments remain unchanged. -// 2. Requested (partial or complete) device specifications given by device name -// for each node are granted. -// 3. Nodes connected by edges of a reference type are colocated on -// the same device. -// 4. Given nodes "A" and "B", if node "B" has a colocation group -// "@loc:A", nodes "A" and "B" will be colocated on the same device. -// -// The implementation builds a constraint graph with the same set of -// nodes, and edges that represent colocation constraints between -// nodes. Each connected component in the resulting constraint graph -// is then assigned to a set of valid devices. -// -// Run() will finally assign the device to each node given the list of -// possible devices. -// -// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as -// possible to node 'y' while respecting the other constraints"? -// TODO(mrry): Create a common interface for this and the other -// placement algorithms so that they may be injected into the graph -// builder. -class SimplePlacer { - public: - // A map from graph node names to numerical IDs (in a Graph object). - typedef std::unordered_map NodeNameToIdMap; - - // Creates an instance of the SimplePlacer algorithm for the given - // Graph "graph" (nodes in which may or may not be assigned) on the - // given DeviceSet "devices". - // - // The "graph", and "devices" pointer arguments - // are borrowed by this SimplePlacer, and must outlive it. 
- SimplePlacer(Graph* graph, const DeviceSet* devices, - const SessionOptions* options); - - SimplePlacer(Graph* graph, const DeviceSet* devices); - - ~SimplePlacer(); - - // Assigns each node in this SimplePlacer's graph to a device in its - // set of devices. - // - // This method is not thread-safe. - // Run() may be invoked at most once. - Status Run(); - - private: - // Returns true if the device type of 'candidate_device_name' is - // found in 'devices'. - bool CanAssignToDevice(const string& candidate_device_name, - const std::vector& devices) const; - - // Assigns 'node's devices to 'assigned_device', and logs the - // placement if the SessionOptions entry in 'options_' requests it. - void AssignAndLog(int assigned_device, Node* node) const; - void LogDeviceAssignment(const Node* node) const; - - Graph* const graph_; // Not owned. - const DeviceSet* const devices_; // Not owned. - const SessionOptions* options_; // Not owned. - const bool log_device_placement_; - - TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc deleted file mode 100644 index 967bee63a1..0000000000 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ /dev/null @@ -1,1285 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/common_runtime/simple_placer.h" - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/device_set.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/core/error_codes.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { - -namespace { - -//////////////////////////////////////////////////////////////////////////////// -// -// Op, kernel, and device registrations to set up the environment. -// -// The SimplePlacer uses information about the op (input types), -// kernel (device constraints), and available devices to make -// placement decisions. To avoid depending on the full runtime, we -// define dummy implementations of these, and register them with the -// runtime. 
-//
-////////////////////////////////////////////////////////////////////////////////
-
-// A dummy OpKernel that is used to register ops on different devices.
-class DummyOp : public OpKernel {
- public:
-  explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
-  void Compute(OpKernelContext* context) override {}
-};
-
-// A fake device that has specific device attributes, used to simulate
-// the presence of a CPU or a GPU (without depending on that part of
-// the runtime).
-class FakeDevice : public Device {
- private:
-  explicit FakeDevice(const DeviceAttributes& device_attributes)
-      : Device(nullptr, device_attributes) {}
-
- public:
-  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
-
-  Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
-
-  static std::unique_ptr<Device> MakeCPU(const string& name) {
-    DeviceAttributes device_attributes;
-    device_attributes.set_name(name);
-    device_attributes.set_device_type(DeviceType("FakeCPU").type());
-    return std::unique_ptr<Device>(new FakeDevice(device_attributes));
-  }
-
-  static std::unique_ptr<Device> MakeGPU(const string& name) {
-    DeviceAttributes device_attributes;
-    device_attributes.set_name(name);
-    device_attributes.set_device_type(DeviceType("FakeGPU").type());
-    return std::unique_ptr<Device>(new FakeDevice(device_attributes));
-  }
-};
-
-class DummyFactory : public DeviceFactory {
- public:
-  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override {
-    return Status::OK();
-  }
-};
-
-// Device order now depends on the registration of devices, not a fixed
-// value in device_set.cc. To avoid the need to link the real CPU and GPU
-// devices into this test, we create fake devices and registrations that
-// can stand in for the real devices for the purposes of testing placement
-// and ordering.
-REGISTER_LOCAL_DEVICE_FACTORY("FakeCPU", DummyFactory);
-REGISTER_LOCAL_DEVICE_FACTORY("FakeGPU", DummyFactory, 51);
-
-// Register the following ops so they can be added to a Graph, and
-// kernels so that they can be placed on particular device types.
-REGISTER_OP("TestVariable").Output("o: Ref(float)"); -REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestVariable").Device("FakeGPU"), DummyOp); - -REGISTER_OP("VariableCPU").Output("o: Ref(float)"); -REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device("FakeCPU"), DummyOp); - -REGISTER_OP("VariableGPU").Output("o: Ref(float)"); -REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device("FakeGPU"), DummyOp); - -REGISTER_OP("VariableNoKernels").Output("o: Ref(float)"); - -REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float"); -REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestAdd").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); -REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestRelu").Device("FakeGPU"), DummyOp); - -REGISTER_OP("ReluCPU").Input("i: float").Output("o: float"); -REGISTER_KERNEL_BUILDER(Name("ReluCPU").Device("FakeCPU"), DummyOp); - -REGISTER_OP("ReluGPU").Input("i: float").Output("o: float"); -REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestAssign").Device("FakeGPU"), DummyOp); - -REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device("FakeCPU"), DummyOp); - -REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestInput").Output("a: float").Output("b: float"); -REGISTER_KERNEL_BUILDER(Name("TestInput").Device("FakeCPU"), DummyOp); - -// Op producing an output that can be placed on CPU or GPU. -REGISTER_OP("TestCPUGPUOutput").Output("a: float"); -REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestCPUGPUOutput").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestGPUOutput").Output("a: float"); -REGISTER_KERNEL_BUILDER(Name("TestGPUOutput").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestDevice").Output("a: float").Output("b: float"); -REGISTER_KERNEL_BUILDER(Name("TestDevice").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float"); -REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp); - -REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp); - -//////////////////////////////////////////////////////////////////////////////// -// -// A SimplePlacerTest method has three phases: -// -// 1. Build a TensorFlow graph, with no (or partial) device assignments. -// 2. Attempt to compute a placement using the SimplePlacer. -// 3. EITHER: test that the constraints implied by the graph are respected; -// or that an appropriate error was reported. -// -//////////////////////////////////////////////////////////////////////////////// -class SimplePlacerTest : public ::testing::Test { - protected: - SimplePlacerTest() { - // Build a set of 10 GPU and 10 CPU devices. - // NOTE: this->local_devices_ owns the device objects; - // this->devices_ contains borrowed pointers to the device - // objects. 
- for (int i = 0; i < 10; ++i) { - local_devices_.emplace_back(FakeDevice::MakeCPU( - strings::StrCat("/job:a/replica:0/task:0/device:fakecpu:", i))); - devices_.AddDevice(local_devices_.back().get()); - // Insert the GPUs in reverse order. - local_devices_.emplace_back(FakeDevice::MakeGPU( - strings::StrCat("/job:a/replica:0/task:0/device:fakegpu:", 9 - i))); - devices_.AddDevice(local_devices_.back().get()); - } - } - - // Builds the given graph, and (if successful) indexes the node - // names for use in placement, and later lookup. - Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { - TF_RETURN_IF_ERROR(builder.ToGraph(out_graph)); - nodes_by_name_.clear(); - for (Node* node : out_graph->nodes()) { - nodes_by_name_[node->name()] = node->id(); - } - return Status::OK(); - } - - // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the - // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices). - // - // REQUIRES: "*graph" was produced by the most recent call to BuildGraph. - Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) { - SimplePlacer placer(graph, devices, options); - return placer.Run(); - } - - Status Place(Graph* graph, DeviceSet* devices) { - return Place(graph, devices, nullptr); - } - - Status Place(Graph* graph, SessionOptions* options) { - return Place(graph, &devices_, options); - } - - Status Place(Graph* graph) { return Place(graph, &devices_, nullptr); } - - // Returns the node in "graph" with the given name. - // - // REQUIRES: "graph" was produced by the most recent call to BuildGraph. - Node* GetNodeByName(const Graph& graph, const string& name) { - const auto search = nodes_by_name_.find(name); - CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name; - return graph.FindNodeId(search->second); - } - - protected: - std::vector> local_devices_; - DeviceSet devices_; - SimplePlacer::NodeNameToIdMap nodes_by_name_; - - Status ReferenceTestHelper(const string& variable_op_type, - const string& assign_op_type, - const DeviceType& expected_device_type); -}; - -#define EXPECT_COLOCATED(g, name_a, name_b) \ - do { \ - Graph& g_ = (g); \ - EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(), \ - GetNodeByName(g_, (name_b))->assigned_device_name()); \ - } while (0) - -#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \ - do { \ - Graph& g_ = (g); \ - EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \ - GetNodeByName(g_, (name_b))->assigned_device_name()); \ - } while (0) - -#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \ - EXPECT_EQ(DeviceType(expected_device_type).type(), \ - devices_ \ - .FindDeviceByName( \ - GetNodeByName((g), (name))->assigned_device_name()) \ - ->attributes() \ - .device_type()) - -#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \ - EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \ - .contains(device_substr)) - -// Test that a graph with no constraints will successfully assign nodes to the -// "best available" device (i.e. prefer GPU over CPU). -TEST_F(SimplePlacerTest, TestNoConstraints) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1")); - ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU"); - EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU"); -} - -// Test that a graph with device type and reference constraints on -// some of the ops will successfully assign nodes to the constrained -// device, and colocate nodes with reference connections. -TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu")); - Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu")); - ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "assign_cpu", "FakeCPU"); - EXPECT_COLOCATED(g, "var_cpu", "assign_cpu"); - EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU"); - EXPECT_DEVICE_TYPE(g, "assign_gpu", "FakeGPU"); - EXPECT_COLOCATED(g, "var_gpu", "assign_gpu"); -} - -TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - - // Normally, shape has a GPU implementation and would be placed - // on GPU. However, because it is a metadata operation, it is - // placed on CPU to avoid transferring the data from CPU to GPU. - ops::UnaryOp("Shape", var_cpu, b.opts().WithName("shape_op")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "shape_op", "FakeCPU"); - EXPECT_COLOCATED(g, "var_cpu", "shape_op"); -} - -// Heuristic A implements "Island fusing": if a node only generates -// an output and it has only one consumer, we place the node -// with its consumer. -TEST_F(SimplePlacerTest, TestHeuristicGeneratorFollowsSingleConsumer) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - - // A variable is only on CPU - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - - // The constant to be assigned can be on both GPU or CPU. - // - // Because of the heuristic, it gets placed on CPU to avoid a - // copy. - Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); - - // The assign is bound to CPU by the reference edge. 
- ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); - - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "var_cpu", "in"); - EXPECT_COLOCATED(g, "assign", "in"); -} - -TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - - // A variable is only on CPU - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - - // The constant to be assigned can only be on GPU. - // - // The heuristic to place the generator with its consumer does - // not apply since the consumer's device is not in the list - // of valid devices for the generator. - Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); - - // The assign is bound to CPU by the reference edge. - ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); - - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); - EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); - EXPECT_COLOCATED(g, "var_cpu", "assign"); -} - -TEST_F(SimplePlacerTest, TestIgnoreGeneratorHeuristicIfWrongPartialDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - - // A variable is only on CPU - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - - // The constant to be assigned can be on CPU or GPU, but is explicitly - // placed on CPU:1. - // - // The heuristic to place the generator with its consumer does - // not apply since the consumer's device is not in the list - // of valid devices for the generator. - Node* input = - ops::SourceOp("TestCPUGPUOutput", - b.opts().WithName("in").WithDevice("/device:fakecpu:1")); - - // The assign is bound to CPU by the reference edge. - ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign")); - - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); - EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu:1"); - EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); - EXPECT_COLOCATED(g, "var_cpu", "assign"); - EXPECT_DEVICE_CONTAINS(g, "var_cpu", "/device:fakecpu:0"); -} - -// Test that a graph with partial device specifications on the ops -// will successfully -TEST_F(SimplePlacerTest, TestPartialSpec) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a")); - ops::SourceOp("TestVariable", - b.opts().WithName("var").WithDevice("/job:a")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); - EXPECT_DEVICE_CONTAINS(g, "in", "/job:a"); - EXPECT_DEVICE_TYPE(g, "var", "FakeGPU"); - EXPECT_DEVICE_CONTAINS(g, "var", "/job:a"); -} - -// Test that a node with a pre-assigned device is not relocated. -TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - GetNodeByName(g, "in")->set_assigned_device_name( - "/job:a/replica:0/task:0/device:fakecpu:7"); - - TF_EXPECT_OK(Place(&g)); - EXPECT_EQ("/job:a/replica:0/task:0/device:fakecpu:7", - GetNodeByName(g, "in")->assigned_device_name()); -} - -// Test that a graph with partial device specifications for CPU-only ops -// will be relocated to CPU. -TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", - b.opts().WithName("in").WithDevice("/device:fakegpu:0")); - ops::SourceOp("TestVariable", - b.opts().WithName("var").WithDevice("/device:fakegpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.set_allow_soft_placement(true); - TF_EXPECT_OK(Place(&g, &options)); - EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); - EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakecpu"); - EXPECT_DEVICE_TYPE(g, "var", "FakeGPU"); - EXPECT_DEVICE_CONTAINS(g, "var", "/device:fakegpu:0"); -} - -// Test that a node with an assigned GPU device but has not registered -// OpKernel will fail. -TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - GetNodeByName(g, "in")->set_assigned_device_name( - "/job:a/replica:0/task:0/device:fakegpu:0"); - - Status s = Place(&g); - EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains( - "Assigned device '/job:a/replica:0/task:0/device:fakegpu:0' " - "does not have registered OpKernel support for TestInput")); -} - -// Test that graphs with reference connections are correctly placed. - -// Build a graph containing a Variable op of "variable_op_type" and an -// Assign op of "assign_op_type", and expect all of the ops to be -// placed on a device of type "expected_device_type". -Status SimplePlacerTest::ReferenceTestHelper( - const string& variable_op_type, const string& assign_op_type, - const DeviceType& expected_device_type) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - // Build ten variable-and-assignment pairs. - for (int i = 0; i < 10; ++i) { - Node* var = ops::SourceOp(variable_op_type, - b.opts().WithName(strings::StrCat("var_", i))); - ops::BinaryOp(assign_op_type, var, input, - b.opts().WithName(strings::StrCat("assign_", i))); - } - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_RETURN_IF_ERROR(Place(&g)); - - for (int i = 0; i < 10; ++i) { - EXPECT_COLOCATED(g, strings::StrCat("var_", i), - strings::StrCat("assign_", i)); - EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type); - EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type); - } - - return Status::OK(); -} - -// Test all 2^3 combinations of Variable and Assignment op types -// (unconstrained, CPU-only, and GPU-only). 
-TEST_F(SimplePlacerTest, TestReferenceConnection) { - Status s; - TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", "FakeGPU")); - TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", "FakeCPU")); - TF_EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", "FakeGPU")); - TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", "FakeCPU")); - TF_EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", "FakeCPU")); - { - Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", "FakeCPU"); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("no device type supports both of those nodes")); - } - TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", "FakeGPU")); - { - Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", "FakeCPU"); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("no device type supports both of those nodes")); - } - TF_EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", "FakeGPU")); -} - -// Handle-using dummy variable ops. -REGISTER_OP("TestHandleVariable").Output("o: resource"); -REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestHandleVariable").Device("FakeGPU"), DummyOp); - -REGISTER_OP("HandleVariableCPU").Output("o: resource"); -REGISTER_KERNEL_BUILDER(Name("HandleVariableCPU").Device("FakeCPU"), DummyOp); - -REGISTER_OP("HandleVariableGPU").Output("o: resource"); -REGISTER_KERNEL_BUILDER(Name("HandleVariableGPU").Device("FakeGPU"), DummyOp); - -REGISTER_OP("TestHandleAssign").Input("i: resource").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeCPU"), DummyOp); -REGISTER_KERNEL_BUILDER(Name("TestHandleAssign").Device("FakeGPU"), DummyOp); - -REGISTER_OP("HandleAssignCPU").Input("i: resource").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp); - -REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float"); -REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp); - -// Tests all combinations of resource handles and ops using them. -TEST_F(SimplePlacerTest, TestResourceHandle) { - auto handle_test = [this](const string& var_op_name, - const string& use_op_name, DeviceType device) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* var = ops::SourceOp(var_op_name, b.opts().WithName("var")); - ops::BinaryOp(use_op_name, var, input, b.opts().WithName("assign")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_RETURN_IF_ERROR(Place(&g)); - - EXPECT_COLOCATED(g, "var", "assign"); - EXPECT_DEVICE_TYPE(g, "var", device); - EXPECT_DEVICE_TYPE(g, "assign", device); - return Status::OK(); - }; - TF_EXPECT_OK( - handle_test("TestHandleVariable", "TestHandleAssign", "FakeGPU")); - TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignCPU", "FakeCPU")); - TF_EXPECT_OK(handle_test("TestHandleVariable", "HandleAssignGPU", "FakeGPU")); - TF_EXPECT_OK(handle_test("HandleVariableCPU", "TestHandleAssign", "FakeCPU")); - TF_EXPECT_OK(handle_test("HandleVariableCPU", "HandleAssignCPU", "FakeCPU")); - TF_EXPECT_OK(handle_test("HandleVariableGPU", "HandleAssignGPU", "FakeGPU")); - TF_EXPECT_OK(handle_test("HandleVariableGPU", "TestHandleAssign", "FakeGPU")); - EXPECT_FALSE( - handle_test("HandleVariableGPU", "HandleAssignCPU", "FakeCPU").ok()); - EXPECT_FALSE( - handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok()); -} - -// Test that an assignment of an operator to the wrong device -// is ignored when it could never be satisfied (due to reference -// edges, for example). -TEST_F(SimplePlacerTest, TestReferenceConnectionIgnoreInfeasible) { - Status s; - Graph g(OpRegistry::Global()); - { - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp( - "TestDevice", - b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); - Node* var = ops::SourceOp("TestVariable", - b.opts().WithName("var_0").WithDevice( - "/job:a/task:0/device:fakegpu:0")); - - // This op is specified on CPU, but in practice will be ignored, - // because the reference edges forces it on GPU. - ops::BinaryOp("TestAssign", var, input, - b.opts().WithName("assign").WithDevice( - "/job:a/task:0/device:fakecpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - s = Place(&g, &options); - TF_EXPECT_OK(s); - EXPECT_DEVICE_TYPE(g, "var_0", "FakeGPU"); - EXPECT_DEVICE_TYPE(g, "assign", "FakeGPU"); -} - -// Test that an assignment of an operator to the a more specified device -// causes the device to maintain its more specific placement. -TEST_F(SimplePlacerTest, - TestReferenceConnectionMoreSpecificDestinationSourceWins) { - Status s; - Graph g(OpRegistry::Global()); - { - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - // Input can be on either device - Node* input = - ops::SourceOp("TestCPUGPUOutput", - b.opts().WithName("in").WithDevice("/job:a/task:0")); - - // Variable can be on either device - Node* var = ops::SourceOp( - "TestVariable", b.opts().WithName("var_0").WithDevice("/job:a/task:0")); - - // This op is specified on CPU and is more specific than the variable. - // Because the variable is less specified, the variable will be - // assigned to CPU. - ops::BinaryOp("TestAssign", var, input, - b.opts().WithName("assign").WithDevice( - "/job:a/task:0/device:fakecpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - s = Place(&g, &options); - TF_EXPECT_OK(s); - EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU"); -} - -// A reference connection exists between a variable and an assign, -// where the assign has a device but the variable does not. 
In this -// case, the variable gets placed on the location of the assign -// operation. -TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) { - Status s; - Graph g(OpRegistry::Global()); - { - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp( - "TestDevice", - b.opts().WithName("in").WithDevice("/job:a/task:0/device:fakegpu:0")); - Node* var = ops::SourceOp("TestVariable", b.opts().WithName("var_0")); - ops::BinaryOp("TestAssign", var, input, - b.opts().WithName("assign").WithDevice( - "/job:a/task:0/device:fakecpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - s = Place(&g, &options); - TF_EXPECT_OK(s); - EXPECT_DEVICE_TYPE(g, "var_0", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU"); -} - -TEST_F(SimplePlacerTest, TestColocationGroup) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* colocated_with_input = ops::UnaryOp( - "TestRelu", input, - b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); - - // This will not be colocated with the input because TestInput is - // only availbale on CPU and TestRelu will default to GPU. - Node* not_colocated_with_input = - ops::UnaryOp("TestRelu", input, b.opts().WithName("foo")); - CHECK(colocated_with_input); - CHECK(not_colocated_with_input); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "in", "colocated_1"); - EXPECT_NOT_COLOCATED(g, "in", "foo"); -} - -TEST_F(SimplePlacerTest, TestMultipleColocationGroups) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* colocated_with_input = ops::UnaryOp( - "TestRelu", input, - b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); - Node* colocated_with_input_and_other = - ops::UnaryOp("TestRelu", input, - b.opts().WithName("foo").WithAttr( - "_class", {"loc:@in", "loc:@colocated_1"})); - CHECK(colocated_with_input); - CHECK(colocated_with_input_and_other); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "in", "colocated_1"); - EXPECT_COLOCATED(g, "in", "foo"); -} - -TEST_F(SimplePlacerTest, TestInvalidMultipleColocationGroups) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* colocated_with_input = ops::UnaryOp( - "ReluCPU", input, - b.opts().WithName("colocated_1").WithAttr("_class", {"loc:@in"})); - Node* colocated_with_input_and_other = - ops::UnaryOp("ReluGPU", input, - b.opts().WithName("foo").WithAttr( - "_class", {"loc:@in", "loc:@colocated_1"})); - CHECK(colocated_with_input); - CHECK(colocated_with_input_and_other); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'foo' and 'in' because no " - "device type supports both of those nodes and the " - "other nodes colocated with them")); -} - -TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); - Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); - - // Two assigns (reference connections) with two different - // colocation groups. Because their colocation groups all map to the - // same device, this is a valid assignment. - ops::BinaryOp( - "TestAssign", var1, input, - b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"})); - ops::BinaryOp( - "TestAssign", var2, input, - b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"})); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "in", "var1"); - EXPECT_COLOCATED(g, "in", "var2"); - EXPECT_COLOCATED(g, "var1", "assign2"); - EXPECT_COLOCATED(g, "var2", "assign1"); -} - -TEST_F(SimplePlacerTest, - TestColocationGroupWithUnsatisfiableReferenceConnections) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - - Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); - Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); - // Var 3 is on GPU - Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3")); - - // Two assigns (reference connections) with two different - // colocation groups. Because their colocation groups all map to the - // same device, this is a valid assignment. - ops::BinaryOp( - "TestAssign", var1, input, - b.opts().WithName("assign1").WithAttr("_class", {"loc:@var1"})); - ops::BinaryOp( - "TestAssign", var2, input, - b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"})); - // Assign to var3, but try to use a colocation group that matches - // the assign of var2. This should fail because assign2 must be on CPU - // (it has a reference edge on var2), and assign3 must be on GPU, - // hence the conflict. - ops::BinaryOp( - "TestAssign", var3, input, - b.opts().WithName("assign3").WithAttr("_class", {"loc:@var2"})); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - "nodes colocated with them.")); -} - -TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - for (int i = 0; i < 10; ++i) { - // Declare ten variable and assignment pairs. - Node* var = ops::SourceOp("TestVariable", - b.opts().WithName(strings::StrCat("var_", i))); - ops::BinaryOp("TestAssign", var, input, - b.opts().WithName(strings::StrCat("assign_", i))); - } - for (int i = 10; i < 100; ++i) { - // Create a variable colocated with some existing variable, and - // an assignment colocated with a possibly-different variable. 
- Node* var = ops::SourceOp( - "TestVariable", - b.opts() - .WithName(strings::StrCat("var_", i)) - .WithAttr("_class", {strings::StrCat("loc:@var_", i % 6)})); - ops::BinaryOp( - "TestAssign", var, input, - b.opts() - .WithName(strings::StrCat("assign_", i)) - .WithAttr("_class", {strings::StrCat("loc:@assign_", i % 3)})); - } - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - for (int i = 0; i < 10; ++i) { - EXPECT_COLOCATED(g, strings::StrCat("var_", i), - strings::StrCat("assign_", i)); - } - for (int i = 10; i < 100; ++i) { - EXPECT_COLOCATED(g, strings::StrCat("var_", i), - strings::StrCat("assign_", i)); - EXPECT_COLOCATED(g, strings::StrCat("var_", i), - strings::StrCat("var_", i % 6)); - EXPECT_COLOCATED(g, strings::StrCat("assign_", i), - strings::StrCat("assign_", i % 3)); - } -} - -// Test that placement fails when no devices are registered. -TEST_F(SimplePlacerTest, TestEmptyDeviceSet) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - DeviceSet empty; - - Status s = Place(&g, &empty); - EXPECT_TRUE( - StringPiece(s.error_message()).contains("No devices are registered")); -} - -// Test that placement fails when the requested device forces an -// indirect constraint to be violated. -TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* in = ops::SourceOp("TestInput", b.opts().WithName("in")); - Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); - ops::BinaryOp("TestAssign", var, in, - b.opts().WithName("assign").WithDevice("/job:b/task:1")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - DeviceSet heterogeneous; - std::unique_ptr gpu( - FakeDevice::MakeGPU("/job:b/replica:0/task:0/device:fakegpu:0")); - heterogeneous.AddDevice(gpu.get()); - std::unique_ptr cpu( - FakeDevice::MakeCPU("/job:b/replica:0/task:1/device:fakecpu:0")); - heterogeneous.AddDevice(cpu.get()); - Status s = Place(&g, &heterogeneous); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("colocated with a group of nodes that required " - "incompatible device")); - - // The error message should contain information that indicates which - // op types have which registered device types. - EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: FakeGPU")) - << s; - EXPECT_TRUE( - StringPiece(s.error_message()).contains("TestAssign: FakeGPU FakeCPU")) - << s; -} - -// Test that placement fails when an unknown device is requested. -TEST_F(SimplePlacerTest, TestUnknownDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo")); -} - -// Test that placement fails when the combination of partial -// constraints leads to an unknown device. -TEST_F(SimplePlacerTest, TestUnknownMergedDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/job:foo")); -} - -// Test that placement fails when the previously-assigned device for a -// node is unknown. -TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - GetNodeByName(g, "in")->set_assigned_device_name("/job:foo"); - - Status s = Place(&g); - EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Assigned device '/job:foo' does not match any device")); -} - -// Test that placement fails when an op with no registered kernels is -// requested. -TEST_F(SimplePlacerTest, TestNoKernelsRegistered) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("VariableNoKernels", b.opts().WithName("var")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains( - "No OpKernel was registered to support Op 'VariableNoKernels'")); - EXPECT_TRUE( - StringPiece(s.error_message()).contains("")); -} - -// Test that placement fails when a kernel is registered but no known -// device supports it. -TEST_F(SimplePlacerTest, TestNoDevicesRegistered) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("VariableGPU", b.opts().WithName("var")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - DeviceSet cpu_only; - std::unique_ptr cpu( - FakeDevice::MakeCPU("/job:a/replica:0/task:0/device:fakecpu:0")); - cpu_only.AddDevice(cpu.get()); - - Status s = Place(&g, &cpu_only); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("No OpKernel was registered to support " - "Op 'VariableGPU'")); - EXPECT_TRUE(StringPiece(s.error_message()).contains("device='FakeGPU'")); -} - -// Test that placement fails when a requested device is malformed. -TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Malformed device specification '/foo:bar'")); -} - -// Test that placement fails when a previously-assigned device is malformed. -TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar"); - - Status s = Place(&g); - EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Malformed assigned device '/foo:bar'")); -} - -// Test that placement fails when a device was previously assigned to -// a node, but it does not uniquely identify a particular device. -TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestInput", b.opts().WithName("in")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - GetNodeByName(g, "in")->set_assigned_device_name("/job:a"); - - Status s = Place(&g); - EXPECT_EQ(error::INTERNAL, s.code()); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("Assigned device '/job:a' does not match any device")); -} - -// Test that ops request to be placed on non-existent devices will be relocated -// to existing device of the same type if allow_soft_placement is set. -TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.set_allow_soft_placement(true); - TF_EXPECT_OK(Place(&g, &options)); - EXPECT_DEVICE_CONTAINS(g, "in", "/device:fakegpu:0"); -} - -// Test that ops request to be placed on non-existent devices will fail if -// allow_soft_placement is not set. -TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakegpu:11")); -} - -// Test that placement fails when a node requests an explicit device that is not -// supported by the registered kernels if allow_soft_placement is no set. -TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("VariableGPU", - b.opts().WithName("var").WithDevice("/device:fakecpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()).contains("/device:fakecpu:0")); - EXPECT_TRUE( - StringPiece(s.error_message()) - .contains("no supported kernel for fakecpu devices is available")); -} - -// Test that placement fails when a node requests an explicit device that is not -// supported by the registered kernels if allow_soft_placement is no set. -TEST_F(SimplePlacerTest, TestNonExistentDevice) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("VariableGPU", - b.opts().WithName("var").WithDevice("/job:foo/replica:17")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - LOG(WARNING) << s.error_message(); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("was explicitly assigned to /job:foo/replica:17 " - "but available devices")); -} - -TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("VariableGPU", - b.opts().WithName("var").WithDevice("/device:fakecpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.set_allow_soft_placement(true); - TF_EXPECT_OK(Place(&g, &options)); -} - -// Test that a graph with device type and reference constraints on -// some of the ops will successfully assign nodes to the constrained -// device, and colocate nodes with reference connections. -TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - // var_gpu has ref output and runs on GPU. - // force_gpu takes var_gpu and requested CPU. - // Verify that both are placed on GPU. - Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu")); - ops::UnaryOp( - "TestDeviceEnforce", var_gpu, - b.opts().WithName("force_gpu").WithDevice("/device:fakecpu:0")); - // var_cpu has ref output and runs on CPU. - // force_cpu takes var_cpu and requested GPU. - // Verify that both are placed on CPU. - Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu")); - ops::UnaryOp( - "TestDeviceEnforce", var_cpu, - b.opts().WithName("force_cpu").WithDevice("/device:fakegpu:0")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.set_allow_soft_placement(true); - TF_EXPECT_OK(Place(&g, &options)); - EXPECT_DEVICE_TYPE(g, "var_gpu", "FakeGPU"); - EXPECT_DEVICE_TYPE(g, "force_gpu", "FakeGPU"); - EXPECT_COLOCATED(g, "var_gpu", "force_gpu"); - EXPECT_DEVICE_TYPE(g, "var_cpu", "FakeCPU"); - EXPECT_DEVICE_TYPE(g, "force_cpu", "FakeCPU"); - EXPECT_COLOCATED(g, "var_cpu", "force_cpu"); -} - -// Test that placement fails when two nodes have a reference connection -// constraint, and each node requires a mutually incompatible device. -TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var")); - Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); - ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - Status s = Place(&g); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(StringPiece(s.error_message()) - .contains("Cannot colocate nodes 'var' and 'assign'")); -} - -// Test that a generator node follows its consumers (where there are several -// consumer nodes on the same devices). -TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - - // A variable is only on CPU - Node* var1_cpu = - ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu")); - Node* var2_cpu = - ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu")); - - // The constant to be assigned can be on both GPU or CPU. - // - // Because of the heuristic, it gets placed on CPU to avoid a - // copy. - Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); - - // The assigns are bound to CPU by the reference edge. - ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1")); - ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2")); - - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "var1_cpu", "in"); - EXPECT_COLOCATED(g, "assign1", "in"); - EXPECT_COLOCATED(g, "var2_cpu", "in"); - EXPECT_COLOCATED(g, "assign2", "in"); -} - -// Test that a generator node does not follow its consumers (where there are -// several consumers on different devices). -TEST_F(SimplePlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - - // A variable is only on CPU - Node* var1_cpu = - ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu")); - Node* var2_cpu = - ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu")); - - // The constant to be assigned can be on both GPU or CPU. - // - // Because of the heuristic, it ought to be on the GPU (cannot be - // co-located with both consumers, so goes to the 'standard' place) - Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in")); - - // The assigns are bound to CPU by the reference edge. - ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1")); - ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2")); - - TF_EXPECT_OK(BuildGraph(b, &g)); - - GetNodeByName(g, "var1_cpu") - ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:1"); - - GetNodeByName(g, "var2_cpu") - ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:2"); - } - - TF_EXPECT_OK(Place(&g)); - EXPECT_COLOCATED(g, "assign1", "var1_cpu"); - EXPECT_COLOCATED(g, "assign2", "var2_cpu"); - EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index caa60e58e6..bfdf967333 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -55,7 +55,7 @@ limitations under the License. namespace tensorflow { -// MasterSession wraps SimpleClientGraph in a reference counted object. +// MasterSession wraps ClientGraph in a reference counted object. // This way, MasterSession can clear up the cache mapping Run requests to // compiled graphs while the compiled graph is still being used. 
 //
@@ -63,10 +63,10 @@ namespace tensorflow {
 class MasterSession::ReffedClientGraph : public core::RefCounted {
  public:
   ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
-                    std::unique_ptr<SimpleClientGraph> cg,
+                    std::unique_ptr<ClientGraph> cg,
                     const SessionOptions& session_opts,
                     const StatsPublisherFactory& stats_publisher_factory,
-                    SimpleGraphExecutionState* execution_state, bool is_partial,
+                    GraphExecutionState* execution_state, bool is_partial,
                     WorkerCacheInterface* worker_cache)
       : session_handle_(handle),
         client_graph_(std::move(cg)),
@@ -87,7 +87,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
 
   ~ReffedClientGraph() override { DeregisterPartitions(); }
 
-  const SimpleClientGraph* client_graph() { return client_graph_.get(); }
+  const ClientGraph* client_graph() { return client_graph_.get(); }
 
   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                     int64 execution_count,
@@ -186,7 +186,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
   // Checks that the requested fetches can be computed from the provided feeds.
   Status CheckFetches(const RunStepRequestWrapper& req,
                       const RunState* run_state,
-                      SimpleGraphExecutionState* execution_state);
+                      GraphExecutionState* execution_state);
 
   string DetailText(const Node& node, const NodeExecStats& ns) {
     int64 tot = 0;
@@ -203,7 +203,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
 
  private:
   const string session_handle_;
-  const std::unique_ptr<SimpleClientGraph> client_graph_;
+  const std::unique_ptr<ClientGraph> client_graph_;
   const SessionOptions session_opts_;
   const bool is_partial_;
   const DebugOptions& debug_opts_;
@@ -813,7 +813,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
 // contention.
 Status MasterSession::ReffedClientGraph::CheckFetches(
     const RunStepRequestWrapper& req, const RunState* run_state,
-    SimpleGraphExecutionState* execution_state) {
+    GraphExecutionState* execution_state) {
   // Build the set of pending feeds that we haven't seen.
   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
   for (const auto& input : run_state->pending_inputs) {
@@ -1028,12 +1028,12 @@ Status MasterSession::Create(GraphDef* graph_def,
     session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
   }
 
-  SimpleGraphExecutionStateOptions execution_options;
+  GraphExecutionStateOptions execution_options;
   execution_options.device_set = devices_.get();
   execution_options.session_options = &session_opts_;
   {
     mutex_lock l(mu_);
-    TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
+    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
         graph_def, execution_options, &execution_state_));
   }
   if (options.cluster_def != nullptr) {
@@ -1142,7 +1142,7 @@ Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
 Status MasterSession::Extend(const ExtendSessionRequest* req,
                              ExtendSessionResponse* resp) {
   UpdateLastAccessTime();
-  std::unique_ptr<SimpleGraphExecutionState> extended_execution_state;
+  std::unique_ptr<GraphExecutionState> extended_execution_state;
   {
     mutex_lock l(mu_);
     if (closed_) {
@@ -1195,7 +1195,7 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
     VLOG(1) << "Unseen hash " << hash << " for "
             << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
             << "\n";
-    std::unique_ptr<SimpleClientGraph> client_graph;
+    std::unique_ptr<ClientGraph> client_graph;
     TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
     WorkerCacheInterface* worker_cache = get_worker_cache();
     auto entry = new ReffedClientGraph(
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 33b9bfe631..51ea92da68 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -21,7 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
 #include "tensorflow/core/common_runtime/device_set.h"
-#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
+#include "tensorflow/core/common_runtime/graph_execution_state.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/distributed_runtime/call_options.h"
 #include "tensorflow/core/distributed_runtime/master_env.h"
@@ -128,7 +128,7 @@ class MasterSession : public core::RefCounted {
   std::atomic<int64> partial_run_handle_counter_ = {0};
 
   mutex mu_;
-  std::unique_ptr<SimpleGraphExecutionState> execution_state_ GUARDED_BY(mu_);
+  std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
   int64 graph_version_;
 
   // We keep a map from a signature of a run request to the
--
cgit v1.2.3
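
For reference, the call sequence visible in the master_session.cc hunks above can be read as a single standalone flow: options carry the DeviceSet and SessionOptions, MakeForBaseGraph builds the execution state from a GraphDef, and BuildGraph produces the ClientGraph that ReffedClientGraph then owns. The sketch below is illustrative only and is not part of the patch; the wrapper function name and exact parameter types are assumptions, while GraphExecutionStateOptions, MakeForBaseGraph, and BuildGraph are taken directly from the diff.

// Illustrative sketch only; mirrors the calls shown in the hunks above.
#include "tensorflow/core/common_runtime/graph_execution_state.h"

namespace tensorflow {

// Hypothetical helper name; the real code inlines this logic in MasterSession.
Status BuildClientGraphForRun(GraphDef* graph_def, const DeviceSet* devices,
                              const SessionOptions* session_options,
                              const BuildGraphOptions& build_opts,
                              std::unique_ptr<ClientGraph>* out_client_graph) {
  // Options mirror the fields set in MasterSession::Create above.
  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices;
  execution_options.session_options = session_options;

  // MakeForBaseGraph may read *graph_def destructively, as noted elsewhere
  // in the patch.
  std::unique_ptr<GraphExecutionState> execution_state;
  TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
      graph_def, execution_options, &execution_state));

  // BuildGraph returns the ClientGraph that ReffedClientGraph wraps.
  return execution_state->BuildGraph(build_opts, out_client_graph);
}

}  // namespace tensorflow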
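
Similarly, the placement interface removed above (simple_placer.h) documents a two-step use: construct the placer over a graph, a device set, and optional session options, then call Run() to assign a device to every node. The sketch below shows that flow against the renamed class; it assumes the renamed Placer keeps the constructor and Run() signatures shown in the deleted header, and the helper function is hypothetical.

// Illustrative sketch only; assumes Placer keeps the interface documented in
// the deleted simple_placer.h (borrowed graph/devices pointers, single Run()).
#include "tensorflow/core/common_runtime/placer.h"

namespace tensorflow {

// Hypothetical helper for illustration.
Status PlaceGraph(Graph* graph, const DeviceSet* devices,
                  const SessionOptions* options) {
  // The graph and device set are borrowed and must outlive the placer.
  Placer placer(graph, devices, options);
  // Run() assigns a device to each node, respecting existing assignments,
  // requested device specs, reference edges, and colocation groups.
  return placer.Run();
}

}  // namespace tensorflow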