diff options
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 197 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 17 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session_test.cc | 50 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_graph_execution_state.cc | 93 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_graph_execution_state.h | 36 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/simple_placer.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/protobuf/config.proto | 9 | ||||
-rw-r--r-- | tensorflow/python/client/session.py | 6 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 42 |
10 files changed, 323 insertions, 130 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 044d732f30..bc97c96416 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1365,6 +1365,7 @@ tf_cc_test( ":test_main", ":testlib", "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:fifo_queue_op", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 0c3ddaf2f7..5f81e8c38c 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -41,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_partition.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/refcount.h" @@ -178,6 +177,23 @@ DirectSession::~DirectSession() { if (options_.config.use_per_session_threads()) { delete thread_pool_; } + + execution_state_.reset(nullptr); + flib_def_.reset(nullptr); +} + +void DirectSession::MaybeInitializeExecutionState(const GraphDef& graph) { + // If already initialied, do nothing. + if (flib_def_ && execution_state_) { + return; + } + // Set up the per-session execution state. + flib_def_.reset(new FunctionLibraryDefinition(graph.library())); + SimpleGraphExecutionStateOptions options; + options.device_set = &device_set_; + options.session_options = &options_; + execution_state_.reset( + new SimpleGraphExecutionState(flib_def_.get(), options)); } Status DirectSession::Create(const GraphDef& graph) { @@ -195,57 +211,14 @@ Status DirectSession::Extend(const GraphDef& graph) { } Status DirectSession::ExtendLocked(const GraphDef& graph) { - // Merge versions - if (graph_def_.has_versions()) { - if (graph_def_.versions().producer() != graph.versions().producer()) { - return errors::InvalidArgument( - "Can't extend GraphDef at version ", graph_def_.versions().producer(), - " with graph at version ", graph.versions().producer()); - } - VersionDef* versions = graph_def_.mutable_versions(); - versions->set_min_consumer( - std::max(versions->min_consumer(), graph.versions().min_consumer())); - if (graph.versions().bad_consumers_size()) { - // Add new bad_consumers that aren't already marked bad. - // - // Note: This implementation is quadratic time if there are many calls to - // ExtendLocked with many bad consumers. Since this is unlikely, and - // fixing it would require data structures outside of this routine, - // quadratic time it is. - auto* bad_consumers = versions->mutable_bad_consumers(); - const std::unordered_set<int> existing(bad_consumers->begin(), - bad_consumers->end()); - for (const int v : graph.versions().bad_consumers()) { - if (existing.find(v) == existing.end()) { - bad_consumers->Add(v); - } - } - } - } else { - graph_def_.mutable_versions()->CopyFrom(graph.versions()); - } - - const int node_size_before_merge = graph_def_.node_size(); - graph_def_.MergeFrom(graph); + MaybeInitializeExecutionState(graph); + std::unique_ptr<SimpleGraphExecutionState> old_state; + SimpleGraphExecutionState* new_state = nullptr; + TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &new_state)); - FunctionLibraryDefinition fdefs(graph_def_.library()); - // Add default attributes to all new nodes in the graph. - Status s = - AddDefaultAttrsToGraphDef(&graph_def_, fdefs, node_size_before_merge); - if (!s.ok()) { - // One of the nodes was invalid, return the state of graph_def_ - // to what it was before this function. - const int nodes_added = graph_def_.node_size() - node_size_before_merge; - graph_def_.mutable_node()->DeleteSubrange(node_size_before_merge, - nodes_added); - return s; - } - - if (graph_def_.versions().producer() >= 5) { - // Validate the graph: we assume that merging two valid graphs - // should maintain graph validity. - TF_RETURN_IF_ERROR(graph::ValidateGraphDef(graph_def_, fdefs)); - } + // Swap out the old state. + old_state = std::move(execution_state_); + execution_state_.reset(new_state); graph_created_ = true; // In case this is first call return Status::OK(); @@ -680,16 +653,19 @@ Status DirectSession::GetOrCreateExecutors( } } + BuildGraphOptions options; + options.feed_endpoints = inputs_sorted; + options.fetch_endpoints = outputs_sorted; + options.target_nodes = tn_sorted; + // The executor_lock_ is intentionally released while executor is // being created. - FunctionLibraryDefinition* fdefs; std::unordered_map<string, Graph*> graphs; - Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs, - run_state_args); + Status s = CreateGraphs(options, &graphs, run_state_args); TF_RETURN_IF_ERROR(s); std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); - ek->func_defs = fdefs; + ek->func_defs = flib_def_.get(); if (run_state_args->is_partial_run) { ek->graph = run_state_args->graph; ek->name_to_node = new NameNodeMap; @@ -724,9 +700,9 @@ Status DirectSession::GetOrCreateExecutors( ek->items.resize(ek->items.size() + 1); auto* item = &(ek->items.back()); - item->flib = - NewFunctionLibraryRuntime(device_mgr_.get(), device, runner, - graph_def_version, fdefs, optimizer_opts); + item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), device, runner, + graph_def_version, flib_def_.get(), + optimizer_opts); LocalExecutorParams params; params.device = device; @@ -802,67 +778,67 @@ Status DirectSession::GetOrCreateExecutors( return Status::OK(); } -void DirectSession::SaveStatefulNodes(Graph* graph) { - for (Node* n : graph->nodes()) { - if (n->op_def().is_stateful()) { - VLOG(2) << "Saving " << n->DebugString(); - stateful_placements_[n->name()] = n->assigned_device_name(); +Status DirectSession::CreateGraphs(const BuildGraphOptions& options, + std::unordered_map<string, Graph*>* outputs, + RunStateArgs* run_state_args) { + std::unique_ptr<SimpleClientGraph> client_graph; + SimpleClientGraph* cgraph = nullptr; + + std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder; + SimpleGraphExecutionState* execution_state = nullptr; + if (options_.config.graph_options().place_pruned_graph()) { + // Because we are placing pruned graphs, we need to create a + // new SimpleGraphExecutorState for every new unseen graph, + // and then place it. + SimpleGraphExecutionStateOptions prune_options; + prune_options.device_set = &device_set_; + prune_options.session_options = &options_; + temp_exec_state_holder.reset( + new SimpleGraphExecutionState(flib_def_.get(), prune_options)); + { + mutex_lock l(mu_); + temp_exec_state_holder->SetStatefulPlacements(stateful_placements_); } + + TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend( + execution_state_->original_graph_def(), &execution_state)); + temp_exec_state_holder.reset(execution_state); + } else { + execution_state = execution_state_.get(); } -} -void DirectSession::RestoreStatefulNodes(Graph* graph) { - for (Node* n : graph->nodes()) { - if (n->op_def().is_stateful()) { - auto iter = stateful_placements_.find(n->name()); - if (iter != stateful_placements_.end()) { - n->set_assigned_device_name(iter->second); - VLOG(2) << "Restored " << n->DebugString(); + TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &cgraph)); + { + auto current_stateful_placements = execution_state->GetStatefulPlacements(); + mutex_lock l(mu_); + // Update our current state based on the execution_state's + // placements. If there are any mismatches for a node, + // we should fail, as this should never happen. + for (auto placement_pair : current_stateful_placements) { + const string& node_name = placement_pair.first; + const string& placement = placement_pair.second; + auto iter = stateful_placements_.find(node_name); + if (iter == stateful_placements_.end()) { + stateful_placements_.insert(std::make_pair(node_name, placement)); + } else if (iter->second != placement) { + return errors::Internal( + "Stateful placement mismatch. " + "Current assignment of ", + node_name, " to ", iter->second, " does not match ", placement); } } - } -} -Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, - gtl::ArraySlice<string> fetches, - gtl::ArraySlice<string> target_nodes, - FunctionLibraryDefinition** func_defs, - std::unordered_map<string, Graph*>* outputs, - RunStateArgs* run_state_args) { - std::unique_ptr<FunctionLibraryDefinition> fdefs; - std::unique_ptr<Graph> graph; - { - mutex_lock l(graph_def_lock_); - fdefs.reset(new FunctionLibraryDefinition(graph_def_.library())); - graph.reset(new Graph(fdefs.get())); - GraphConstructorOptions opts; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get())); + stateful_placements_ = execution_state->GetStatefulPlacements(); } + client_graph.reset(cgraph); // Remember the graph in run state if this is a partial run. if (run_state_args->is_partial_run) { - run_state_args->graph = new Graph(fdefs.get()); - CopyGraph(*graph.get(), run_state_args->graph); - } - - TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - graph.get(), feeds, fetches, target_nodes, - device_set_.client_device()->attributes())); - - // Run the simple placer after rewriting the graph. - SimplePlacer placer(graph.get(), &device_set_, &options_); - - { - mutex_lock l(mu_); - // Restore stateful nodes. - RestoreStatefulNodes(graph.get()); - TF_RETURN_IF_ERROR(placer.Run()); - // Save stateful nodes. - SaveStatefulNodes(graph.get()); + run_state_args->graph = new Graph(flib_def_.get()); + CopyGraph(*execution_state->full_graph(), run_state_args->graph); } // Partition the graph across devices. - std::unordered_map<string, GraphDef> partitions; PartitionOptions popts; popts.node_to_loc = [](const Node* node) { return node->assigned_device_name(); @@ -877,7 +853,9 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, return 1; }; popts.control_flow_added = false; - TF_RETURN_IF_ERROR(Partition(popts, graph.get(), &partitions)); + + std::unordered_map<string, GraphDef> partitions; + TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions)); std::vector<string> device_names; for (auto device : devices_) { @@ -917,12 +895,12 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, // may be possible use cases where a device may want to modify // function definitions - in which case the library would need to be // replicated per device. - s = d->MaybeRewriteGraph(graph_def_.library(), graph_def); + s = d->MaybeRewriteGraph(flib_def_->ToProto(), graph_def); if (!s.ok()) { break; } } - Graph* device_graph = new Graph(fdefs.get()); + Graph* device_graph = new Graph(flib_def_.get()); GraphConstructorOptions device_opts; // There are internal operations (e.g., send/recv) that we now // allow. @@ -940,7 +918,6 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds, gtl::STLDeleteValues(outputs); return s; } - *func_defs = fdefs.release(); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index d64460bd1f..6a531c4ec8 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/simple_graph_execution_state.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/session_state.h" @@ -119,7 +120,6 @@ class DirectSession : public Session { delete item.executor; delete item.flib; } - delete func_defs; delete graph; delete name_to_node; } @@ -159,6 +159,10 @@ class DirectSession : public Session { Graph* graph = nullptr; }; + // Initializes the base execution state given the 'graph', + // if not already initialized. + void MaybeInitializeExecutionState(const GraphDef& graph); + // Retrieves an already existing set of executors to run 'inputs' and // 'outputs', or creates and caches them for future use. ::tensorflow::Status GetOrCreateExecutors( @@ -168,10 +172,7 @@ class DirectSession : public Session { // Creates several graphs given the existing graph_def_ and the // input feeds and fetches, given 'devices'. - ::tensorflow::Status CreateGraphs(gtl::ArraySlice<string> feeds, - gtl::ArraySlice<string> fetches, - gtl::ArraySlice<string> target_nodes, - FunctionLibraryDefinition** func_defs, + ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options, std::unordered_map<string, Graph*>* outputs, RunStateArgs* run_state_args); @@ -239,14 +240,16 @@ class DirectSession : public Session { // Saves and restores device placements for stateful nodes. mutex mu_; - void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); - void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these // nodes can not be moved to a different device. Maps node names to // device names. std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_); + // Execution_state; used when placing the entire graph. + std::unique_ptr<SimpleGraphExecutionState> execution_state_; + std::unique_ptr<FunctionLibraryDefinition> flib_def_; + // For generating unique names. int64 name_counter_ GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 75a1235f0b..7347f255d8 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -240,7 +240,9 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) { test::graph::ToGraphDef(&graph, &def); - std::unique_ptr<Session> session(CreateSession()); + SessionOptions options; + (*options.config.mutable_device_count())["CPU"] = 2; + std::unique_ptr<Session> session(NewSession(options)); ASSERT_TRUE(session != nullptr); TF_ASSERT_OK(session->Create(def)); std::vector<std::pair<string, Tensor>> inputs; @@ -254,7 +256,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) { def.Clear(); y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1"); test::graph::ToGraphDef(&graph, &def); - session.reset(CreateSession()); + session.reset(NewSession(options)); TF_ASSERT_OK(session->Create(def)); TF_ASSERT_OK(session->Run(inputs, output_names, {}, &outputs)); } @@ -431,6 +433,50 @@ TEST(DirectSessionTest, DarthKernel) { delete sess; } +// Have the Darth op in the graph placed on GPU, but don't run it. +TEST(DirectSessionTest, PlacePrunedGraph) { + { + Graph g(OpRegistry::Global()); + Tensor vx(DT_FLOAT, TensorShape({})); + vx.scalar<float>()() = 1.0; + Node* x = test::graph::Constant(&g, vx); + Node* y = test::graph::Unary(&g, "Darth", x); + y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); + GraphDef def; + test::graph::ToGraphDef(&g, &def); + + // By default, we place the entire graph, so we should fail the + // call to Run, even if we don't run the bad op. + SessionOptions options; + std::unique_ptr<Session> sess(NewSession(options)); + TF_ASSERT_OK(sess->Create(def)); + std::vector<Tensor> outputs; + auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs); + EXPECT_TRUE(errors::IsInvalidArgument(s)); + } + + { + Graph g(OpRegistry::Global()); + Tensor vx(DT_FLOAT, TensorShape({})); + vx.scalar<float>()() = 1.0; + Node* x = test::graph::Constant(&g, vx); + Node* y = test::graph::Unary(&g, "Darth", x); + y->set_assigned_device_name("/job:localhost/replica:0/task:0/gpu:0"); + GraphDef def; + test::graph::ToGraphDef(&g, &def); + + SessionOptions options; + // Set the option to place pruned graphs, we should expect this + // to run. + options.config.mutable_graph_options()->set_place_pruned_graph(true); + std::unique_ptr<Session> sess(NewSession(options)); + TF_ASSERT_OK(sess->Create(def)); + std::vector<Tensor> outputs; + auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs); + TF_EXPECT_OK(s); + } +} + TEST(DirectSessionTest, PartialRunTest) { GraphDef def; Graph g(OpRegistry::Global()); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 8fae2634de..6ba32814ac 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -87,13 +88,52 @@ Status SimpleGraphExecutionState::Extend( } } + // 3. Merge the versions field. int old_node_size = gdef.node_size(); gdef.mutable_node()->MergeFrom(extension_def.node()); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size)); + // Merge versions + if (gdef.has_versions()) { + if (gdef.versions().producer() != extension_def.versions().producer()) { + return errors::InvalidArgument( + "Can't extend GraphDef at version ", gdef.versions().producer(), + " with graph at version ", extension_def.versions().producer()); + } + VersionDef* versions = gdef.mutable_versions(); + versions->set_min_consumer(std::max( + versions->min_consumer(), extension_def.versions().min_consumer())); + if (extension_def.versions().bad_consumers_size()) { + // Add new bad_consumers that aren't already marked bad. + // + // Note: This implementation is quadratic time if there are many calls to + // ExtendLocked with many bad consumers. Since this is unlikely, and + // fixing it would require data structures outside of this routine, + // quadratic time it is. + auto* bad_consumers = versions->mutable_bad_consumers(); + const std::unordered_set<int> existing(bad_consumers->begin(), + bad_consumers->end()); + for (const int v : extension_def.versions().bad_consumers()) { + if (existing.find(v) == existing.end()) { + bad_consumers->Add(v); + } + } + } - // 3. Add the extension. + } else { + gdef.mutable_versions()->CopyFrom(extension_def.versions()); + } + + // 5. Validate that the final graphdef is valid. + if (gdef.versions().producer() >= 5) { + // Validate the graph: we assume that merging two valid graphs + // should maintain graph validity. + TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *ops_)); + } + + // 6. Add the extension. SimpleGraphExecutionStateOptions combined_options; combined_options.device_set = device_set_; + combined_options.session_options = session_options_; SimpleGraphExecutionState* new_execution_state = new SimpleGraphExecutionState(ops_, combined_options); @@ -102,6 +142,7 @@ Status SimpleGraphExecutionState::Extend( delete new_execution_state; return new_execution_state_status; } + new_execution_state->SetStatefulPlacements(GetStatefulPlacements()); *out = new_execution_state; // TODO(mrry): This is likely to be used for non-throughput-sensitive @@ -110,14 +151,47 @@ Status SimpleGraphExecutionState::Extend( return Status::OK(); } -Status SimpleGraphExecutionState::InitBaseGraph() { +void SimpleGraphExecutionState::SaveStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + VLOG(2) << "Saving " << n->DebugString(); + stateful_placements_[n->name()] = n->assigned_device_name(); + } + } +} + +void SimpleGraphExecutionState::RestoreStatefulNodes(Graph* graph) { + for (Node* n : graph->nodes()) { + if (n->op_def().is_stateful()) { + auto iter = stateful_placements_.find(n->name()); + if (iter != stateful_placements_.end()) { + n->set_assigned_device_name(iter->second); + VLOG(2) << "Restored " << n->DebugString(); + } + } + } +} + +Status SimpleGraphExecutionState::InitBaseGraph( + const BuildGraphOptions& options) { std::unique_ptr<Graph> new_graph(new Graph(ops_)); GraphConstructorOptions opts; TF_RETURN_IF_ERROR( ConvertGraphDefToGraph(opts, original_graph_def_, new_graph.get())); + if (session_options_ && + session_options_->config.graph_options().place_pruned_graph()) { + // Rewrite the graph before placement. + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + new_graph.get(), options.feed_endpoints, options.fetch_endpoints, + options.target_nodes, device_set_->client_device()->attributes())); + } + + // Save stateful placements before placing. + RestoreStatefulNodes(new_graph.get()); SimplePlacer placer(new_graph.get(), device_set_, session_options_); // TODO(mrry): Consider making the SimplePlacer cancelable. TF_RETURN_IF_ERROR(placer.Run()); + SaveStatefulNodes(new_graph.get()); graph_ = new_graph.release(); return Status::OK(); } @@ -128,17 +202,20 @@ Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options, mutex_lock l(mu_); // Lazily initialize the base graph. if (!graph_) { - TF_RETURN_IF_ERROR(InitBaseGraph()); + TF_RETURN_IF_ERROR(InitBaseGraph(options)); } std::unique_ptr<SimpleClientGraph> cgraph(new SimpleClientGraph(ops_)); CopyGraph(*graph_, &cgraph->graph); - // Extract the subset of the graph that needs to be run, adding feed/fetch - // ops as needed. - TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - &cgraph->graph, options.feed_endpoints, options.fetch_endpoints, - options.target_nodes, device_set_->client_device()->attributes())); + if (session_options_ == nullptr || + !session_options_->config.graph_options().place_pruned_graph()) { + // Extract the subset of the graph that needs to be run, adding feed/fetch + // ops as needed. + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + &cgraph->graph, options.feed_endpoints, options.fetch_endpoints, + options.target_nodes, device_set_->client_device()->attributes())); + } // Copy the extracted graph in order to make its node ids dense, // since the local CostModel used to record its stats is sized by diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index bf3a85b5a2..02357c8037 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -114,10 +114,44 @@ class SimpleGraphExecutionState { // execution, e.g. a send, recv or feed node. Status GlobalNodeDefByName(const string& name, NodeDef* out); + // The graph returned by BuildGraph may contain only the pruned + // graph, whereas some clients may want access to the full graph. + const Graph* full_graph() { + mutex_lock l(mu_); + return graph_; + } + + // Returns a reference to the current graph_def. Use must + // not extend beyond lifetime of SimpleGrahExecutionState object. + const GraphDef& original_graph_def() { return original_graph_def_; } + + // Returns the map of stateful placements as a map of + // node name to placement string. + std::unordered_map<string, string> GetStatefulPlacements() const { + mutex_lock l(mu_); + return stateful_placements_; + } + + // Restores the map of stateful placements as a map of + // node name to placement string. + void SetStatefulPlacements(const std::unordered_map<string, string>& sp) { + mutex_lock l(mu_); + stateful_placements_ = sp; + } + private: mutable mutex mu_; - Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status InitBaseGraph(const BuildGraphOptions& options) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Map of placed stateful nodes, i.e. nodes for which is_stateful() + // is true, such as "params" and "queue" nodes. Once placed these + // nodes can not be moved to a different device. Maps node names to + // device names. + std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_); + void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); const OpRegistryInterface* const ops_; // Not owned GraphDef original_graph_def_; // Immutable after ctor. diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 19a44c6a98..c14c37a4c3 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -469,7 +469,6 @@ class ColocationGraph { // If the NodeDef contains a device, then we interpret it as a // (partial) device specification. - string colocated_node_name; if (!node.def().device().empty()) { // The user has specified a device in the NodeDef, try to find a // valid device matching their specification in the set of @@ -659,7 +658,6 @@ Status SimplePlacer::Run() { // and we can experiment with other algorithms when given a choice of // devices. node->set_assigned_device_name(devices[0]->name()); - // Log placement if log_device_placement is set. if (options_ && options_->config.log_device_placement()) { printf("%s: %s\n", node->name().c_str(), diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 3bec167206..fe926a184d 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -80,6 +80,15 @@ message GraphOptions { // Annotate each Node with Op output shape data, to the extent it can // be statically inferred. bool infer_shapes = 5; + + // Only place the subgraphs that are run, rather than the entire graph. + // + // This is useful for interactive graph building, where one might + // produce graphs that cannot be placed during the debugging + // process. In particular, it allows the client to continue work in + // a session after adding a node to a graph whose placement + // constraints are unsatisfiable. + bool place_pruned_graph = 6; }; // Session configuration parameters. diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index b0c82b8874..fbc48cdbba 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -24,6 +24,7 @@ import threading import numpy as np +from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -897,6 +898,11 @@ class InteractiveSession(BaseSession): graph: (Optional.) The `Graph` to be launched (described above). config: (Optional) `ConfigProto` proto used to configure the session. """ + if not config: + config = config_pb2.ConfigProto() + # Interactive sessions always place pruned graphs. + config.graph_options.place_pruned_graph = True + super(InteractiveSession, self).__init__(target, graph, config) self._default_session = self.as_default() self._default_session.__enter__() diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index d60883464d..09698d2c78 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -707,6 +707,48 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual([[24.0]], e.eval()) sess.close() + def testInteractivePlacePrunedGraph(self): + sess = session.InteractiveSession() + + # Build a graph that has a bad op in it (no kernel). + # + # This test currently does not link in any GPU kernels, + # which is why placing this is invalid. If at some point + # GPU kernels are added to this test, some other different + # op / device combo should be chosen. + with ops.device('/gpu:0'): + a = constant_op.constant(1.0, shape=[1, 2]) + + b = constant_op.constant(1.0, shape=[1, 2]) + + # Only run the valid op, this should work. + b.eval() + + with self.assertRaises(errors.InvalidArgumentError): + a.eval() + sess.close() + + def testDefaultSessionPlacePrunedGraph(self): + sess = session.Session() + + # Build a graph that has a bad op in it (no kernel). + # + # This test currently does not link in any GPU kernels, + # which is why placing this is invalid. If at some point + # GPU kernels are added to this test, some other different + # op / device combo should be chosen. + with ops.device('/gpu:0'): + _ = constant_op.constant(1.0, shape=[1, 2]) + + b = constant_op.constant(1.0, shape=[1, 2]) + + with self.assertRaises(errors.InvalidArgumentError): + # Even though we don't run the bad op, we place the entire + # graph, which should fail with a non-interactive session. + sess.run(b) + + sess.close() + def testSharedGraph(self): with ops.Graph().as_default() as g, ops.device('/cpu:0'): a = constant_op.constant(1.0, shape=[1, 2]) |