diff options
9 files changed, 200 insertions, 134 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 52a810d76b..848a71c474 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -285,28 +285,47 @@ DirectSession::~DirectSession() { flib_def_.reset(nullptr); } -void DirectSession::MaybeInitializeExecutionState(const GraphDef& graph) { +Status DirectSession::MaybeInitializeExecutionState( + const GraphDef& graph, bool* out_already_initialized) { // If already initialized, do nothing. if (flib_def_ && execution_state_) { - return; + *out_already_initialized = true; + return Status::OK(); } // Set up the per-session execution state. + // NOTE(mrry): The function library created here will be used for + // all subsequent extensions of the graph. flib_def_.reset( new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); SimpleGraphExecutionStateOptions options; options.device_set = &device_set_; options.session_options = &options_; - execution_state_.reset( - new SimpleGraphExecutionState(graph.library(), options)); + // TODO(mrry,suharshs): We explicitly copy `graph` so that + // `MakeForBaseGraph()` can take ownership of its + // contents. Previously this happened implicitly in calls to the + // `SimpleGraphExecutionState`. Other sessions call + // `MakeForBaseGraph` in such a way that we can destructively read + // the passed-in `GraphDef`. In principle we could do the same here, + // with a wider refactoring; we might revise the direct session so + // that it copies the graph fewer times. + GraphDef temp(graph); + TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( + &temp, options, &execution_state_)); + graph_created_ = true; + *out_already_initialized = false; + return Status::OK(); } Status DirectSession::Create(const GraphDef& graph) { - mutex_lock l(graph_def_lock_); - if (graph_created_) { - return errors::AlreadyExists( - "A Graph has already been created for this session."); + if (graph.node_size() > 0) { + mutex_lock l(graph_def_lock_); + if (graph_created_) { + return errors::AlreadyExists( + "A Graph has already been created for this session."); + } + return ExtendLocked(graph); } - return ExtendLocked(graph); + return Status::OK(); } Status DirectSession::Extend(const GraphDef& graph) { @@ -316,12 +335,16 @@ Status DirectSession::Extend(const GraphDef& graph) { } Status DirectSession::ExtendLocked(const GraphDef& graph) { - MaybeInitializeExecutionState(graph); - std::unique_ptr<SimpleGraphExecutionState> state; - TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); - execution_state_.swap(state); - - graph_created_ = true; // In case this is first call + bool already_initialized; + // If this is the first call, we can initialize the execution state + // with `graph` and do not need to call `Extend()`. + TF_RETURN_IF_ERROR( + MaybeInitializeExecutionState(graph, &already_initialized)); + if (already_initialized) { + std::unique_ptr<SimpleGraphExecutionState> state; + TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); + execution_state_.swap(state); + } return Status::OK(); } @@ -949,7 +972,7 @@ Status DirectSession::GetOrCreateExecutors( } Status DirectSession::CreateGraphs( - const BuildGraphOptions& options, + const BuildGraphOptions& subgraph_options, std::unordered_map<string, std::unique_ptr<Graph>>* outputs, std::unique_ptr<FunctionLibraryDefinition>* flib_def, RunStateArgs* run_state_args) { @@ -960,23 +983,23 @@ Status DirectSession::CreateGraphs( SimpleGraphExecutionState* execution_state = nullptr; if (options_.config.graph_options().place_pruned_graph()) { // Because we are placing pruned graphs, we need to create a - // new SimpleGraphExecutorState for every new unseen graph, + // new SimpleGraphExecutionState for every new unseen graph, // and then place it. SimpleGraphExecutionStateOptions prune_options; prune_options.device_set = &device_set_; prune_options.session_options = &options_; - temp_exec_state_holder.reset(new SimpleGraphExecutionState( - execution_state_->original_graph_def().library(), prune_options)); - temp_exec_state_holder->SetStatefulPlacements(stateful_placements_); - - TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend( - execution_state_->original_graph_def(), &temp_exec_state_holder)); + prune_options.stateful_placements = stateful_placements_; + TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph( + execution_state_->original_graph_def().library(), prune_options, + execution_state_->original_graph_def(), subgraph_options, + &temp_exec_state_holder, &client_graph)); execution_state = temp_exec_state_holder.get(); } else { execution_state = execution_state_.get(); + TF_RETURN_IF_ERROR( + execution_state->BuildGraph(subgraph_options, &client_graph)); } - TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph)); auto current_stateful_placements = execution_state->GetStatefulPlacements(); // Update our current state based on the execution_state's // placements. If there are any mismatches for a node, diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 8e757e9273..a428911253 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -164,7 +164,8 @@ class DirectSession : public Session { // Initializes the base execution state given the 'graph', // if not already initialized. - void MaybeInitializeExecutionState(const GraphDef& graph) + Status MaybeInitializeExecutionState(const GraphDef& graph, + bool* out_already_initialized) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); // Retrieves an already existing set of executors to run 'inputs' and diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 4eb48c7bcf..38dd627da0 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -244,13 +244,8 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) { (*options.config.mutable_device_count())["CPU"] = 2; std::unique_ptr<Session> session(NewSession(options)); ASSERT_TRUE(session != nullptr); - TF_ASSERT_OK(session->Create(def)); - std::vector<std::pair<string, Tensor>> inputs; - std::vector<string> output_names = {y->name() + ":0"}; - std::vector<Tensor> outputs; - // Should return an error. - ASSERT_FALSE(session->Run(inputs, output_names, {}, &outputs).ok()); + ASSERT_FALSE(session->Create(def).ok()); // Fix placement and run again def.Clear(); @@ -258,7 +253,8 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) { test::graph::ToGraphDef(&graph, &def); session.reset(NewSession(options)); TF_ASSERT_OK(session->Create(def)); - TF_ASSERT_OK(session->Run(inputs, output_names, {}, &outputs)); + std::vector<Tensor> outputs; + TF_ASSERT_OK(session->Run({}, {y->name() + ":0"}, {}, &outputs)); } TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) { @@ -454,12 +450,10 @@ TEST(DirectSessionTest, PlacePrunedGraph) { test::graph::ToGraphDef(&g, &def); // By default, we place the entire graph, so we should fail the - // call to Run, even if we don't run the bad op. + // call to Create. SessionOptions options; std::unique_ptr<Session> sess(NewSession(options)); - TF_ASSERT_OK(sess->Create(def)); - std::vector<Tensor> outputs; - auto s = sess->Run({}, {x->name() + ":0"}, {}, &outputs); + auto s = sess->Create(def); EXPECT_TRUE(errors::IsInvalidArgument(s)); } diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 365120dd0e..2be12e7ad8 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -40,33 +40,68 @@ limitations under the License. namespace tensorflow { SimpleGraphExecutionState::SimpleGraphExecutionState( - const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options) - : device_set_(options.device_set), + GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options) + : stateful_placements_(options.stateful_placements), + device_set_(options.device_set), session_options_(options.session_options), costs_(true /*is_global*/), - flib_def_( - new FunctionLibraryDefinition(OpRegistry::Global(), func_def_lib)), + flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(), + graph_def->library())), graph_(nullptr) { + // NOTE(mrry): GraphDef does not have a move constructor, so we pass + // a non-const pointer and use `Swap()` to transfer the contents + // without copying. + original_graph_def_.Swap(graph_def); // TODO(mrry): Publish placement visualizations or handle the log // placement option. } SimpleGraphExecutionState::~SimpleGraphExecutionState() { - mutex_lock l(mu_); node_name_to_cost_id_map_.clear(); delete graph_; } -Status SimpleGraphExecutionState::Create(GraphDef* graph_def) { - if (original_graph_def_.node_size() > 0) { - return errors::InvalidArgument( - "Cannot call Create on SimpleGraphExecutionState twice"); +/* static */ Status SimpleGraphExecutionState::MakeForBaseGraph( + GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, + std::unique_ptr<SimpleGraphExecutionState>* out_state) { + std::unique_ptr<SimpleGraphExecutionState> ret( + new SimpleGraphExecutionState(graph_def, options)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_, + *ret->flib_def_.get(), 0)); + // TODO(mrry): Refactor InitBaseGraph() so that we don't have to + // pass an empty BuildGraphOptions (that isn't going to be used when + // place_pruned_graph is false). + if (!ret->session_options_->config.graph_options().place_pruned_graph()) { + TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions())); } + *out_state = std::move(ret); + return Status::OK(); +} - original_graph_def_.Swap(graph_def); - VLOG(2) << "Incoming def: " << ProtoDebugString(original_graph_def_); - return AddDefaultAttrsToGraphDef(&original_graph_def_, *flib_def_.get(), 0); +/* static */ Status SimpleGraphExecutionState::MakeForPrunedGraph( + const FunctionDefLibrary& func_def_lib, + const SimpleGraphExecutionStateOptions& options, const GraphDef& graph_def, + const BuildGraphOptions& subgraph_options, + std::unique_ptr<SimpleGraphExecutionState>* out_state, + std::unique_ptr<SimpleClientGraph>* out_client_graph) { + DCHECK(options.session_options->config.graph_options().place_pruned_graph()); + // NOTE(mrry): This makes a copy of `graph_def`, which is + // regrettable. We could make `GraphDef` objects sharable between + // execution states to optimize pruned graph execution, but since + // this case is primarily used for interactive sessions, we make the + // bet that graph construction is not performance-critical. (Note + // also that the previous version used `Extend()`, which is strictly + // more expensive than copying a `GraphDef`.) + GraphDef temp(graph_def); + std::unique_ptr<SimpleGraphExecutionState> ret( + new SimpleGraphExecutionState(&temp, options)); + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&ret->original_graph_def_, + *ret->flib_def_.get(), 0)); + TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options)); + TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph)); + *out_state = std::move(ret); + return Status::OK(); } Status SimpleGraphExecutionState::Extend( @@ -129,6 +164,12 @@ Status SimpleGraphExecutionState::Extend( gdef.mutable_versions()->CopyFrom(extension_def.versions()); } + // 4. Copy the function library from this execution state. + // NOTE(mrry): To match the previous behavior, the first GraphDef + // passed to a session will contain the function library that is + // used for all subsequent execution states. + *gdef.mutable_library() = flib_def_->ToProto(); + // 5. Validate that the final graphdef is valid. if (gdef.versions().producer() >= 5) { // Validate the graph: we assume that merging two valid graphs @@ -140,11 +181,21 @@ Status SimpleGraphExecutionState::Extend( SimpleGraphExecutionStateOptions combined_options; combined_options.device_set = device_set_; combined_options.session_options = session_options_; + combined_options.stateful_placements = stateful_placements_; + // NOTE(mrry): `gdef` is no longer valid after the constructor + // executes. std::unique_ptr<SimpleGraphExecutionState> new_execution_state( - new SimpleGraphExecutionState(flib_def_->ToProto(), combined_options)); - TF_RETURN_IF_ERROR(new_execution_state->Create(&gdef)); - new_execution_state->SetStatefulPlacements(GetStatefulPlacements()); + new SimpleGraphExecutionState(&gdef, combined_options)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( + &new_execution_state->original_graph_def_, *flib_def_.get(), 0)); + if (!session_options_->config.graph_options().place_pruned_graph()) { + // TODO(mrry): Refactor InitBaseGraph() so that we don't have to + // pass an empty BuildGraphOptions (that isn't going to be used + // when place_pruned_graph is false). + TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions())); + } *out = std::move(new_execution_state); // TODO(mrry): This is likely to be used for non-throughput-sensitive @@ -196,8 +247,11 @@ Status SimpleGraphExecutionState::InitBaseGraph( RestoreStatefulNodes(new_graph.get()); CostModel costs(true /*is_global*/); - costs_.InitFromGraph(*new_graph.get()); - costs.MergeFromGlobal(costs_); + { + mutex_lock l(mu_); + costs_.InitFromGraph(*new_graph.get()); + costs.MergeFromGlobal(costs_); + } GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; @@ -230,30 +284,9 @@ void SimpleGraphExecutionState::MergeCostsFromGlobal(CostModel* costs) { costs->MergeFromGlobal(costs_); } -Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name, - NodeDef* out) { - NodeNameToCostIdMap::const_iterator iter = - node_name_to_cost_id_map_.find(name); - if (iter != node_name_to_cost_id_map_.end()) { - mutex_lock l(mu_); // could use reader lock - const Node* node = graph_->FindNodeId(iter->second); - if (node) { - *out = node->def(); - return Status::OK(); - } - } - return errors::NotFound("Node name: ", name); -} - Status SimpleGraphExecutionState::BuildGraph( const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) { VLOG(1) << "BuildGraph"; - mutex_lock l(mu_); - // Lazily initialize the base graph. - if (!graph_) { - TF_RETURN_IF_ERROR(InitBaseGraph(options)); - } - std::unique_ptr<Graph> ng(new Graph(flib_def_.get())); CopyGraph(*graph_, ng.get()); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index 5fe16f0f42..2a33d9e298 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -42,6 +42,9 @@ class Timeline; struct SimpleGraphExecutionStateOptions { const DeviceSet* device_set = nullptr; const SessionOptions* session_options = nullptr; + // A map from node name to device name, representing the unchangeable + // placement of stateful nodes. + std::unordered_map<string, string> stateful_placements; }; // A SimpleClientGraph is simply a sub-graph of the full graph as induced by @@ -82,15 +85,29 @@ struct SimpleClientGraph { class SimpleGraphExecutionState { public: - SimpleGraphExecutionState(const FunctionDefLibrary& func_def_lib, - const SimpleGraphExecutionStateOptions& options); - virtual ~SimpleGraphExecutionState(); - // Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be - // called once on an original SimpleGraphExecutionState. Callee may modify - // 'graph_def'. - Status Create(GraphDef* graph_def); + // Creates a new `SimpleGraphExecutionState` for the given + // `graph_def`, which represents the entire graph for a session. + // + // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def` + // in an undefined state. If it is necessary to use `*graph_def` + // after this call, make an explicit copy of the graph before + // calling this method. + static Status MakeForBaseGraph( + GraphDef* graph_def, const SimpleGraphExecutionStateOptions& options, + std::unique_ptr<SimpleGraphExecutionState>* out_state); + + // Creates a new `SimpleGraphExecutionState` and `SimpleClientGraph` + // for the subgraph of `original_graph_def` defined by + // `subgraph_options`. + static Status MakeForPrunedGraph( + const FunctionDefLibrary& func_def_lib, + const SimpleGraphExecutionStateOptions& options, + const GraphDef& original_graph_def, + const BuildGraphOptions& subgraph_options, + std::unique_ptr<SimpleGraphExecutionState>* out_state, + std::unique_ptr<SimpleClientGraph>* out_client_graph); // Creates a new SimpleGraphExecutionState representing the // concatenation of this graph, and the graph defined by @@ -100,6 +117,9 @@ class SimpleGraphExecutionState { // If successful, returns OK and the caller takes ownership of "*out". // Otherwise returns an error and does not modify "*out". // + // After calling `old_state->Extend()`, `old_state` may no longer be + // used. + // // NOTE(mrry): This method respects the placement of stateful nodes in // in *this, but currently does not transfer any other placement // or cost model information to the new graph. @@ -113,12 +133,6 @@ class SimpleGraphExecutionState { Status BuildGraph(const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out); - // Returns OK if the named node is found in the placed full graph owned - // by this execution_state, and sets *out to the NodeDef for that node. - // It may not exist if name is of a Node added for a particular subgraph - // execution, e.g. a send, recv or feed node. - Status GlobalNodeDefByName(const string& name, NodeDef* out); - // Sums execution statistics in "ss" into the CostModel. void UpdateCostsFromStats(const StepStats& ss); @@ -138,10 +152,20 @@ class SimpleGraphExecutionState { // The graph returned by BuildGraph may contain only the pruned // graph, whereas some clients may want access to the full graph. const Graph* full_graph() { - mutex_lock l(mu_); return graph_; } + // Returns the node with the given name, or null if it does not exist. + const Node* get_node_by_name(const string& name) const { + NodeNameToCostIdMap::const_iterator iter = + node_name_to_cost_id_map_.find(name); + if (iter != node_name_to_cost_id_map_.end()) { + return graph_->FindNodeId(iter->second); + } else { + return nullptr; + } + } + // Returns a reference to the current graph_def. Use must // not extend beyond lifetime of SimpleGrahExecutionState object. const GraphDef& original_graph_def() { return original_graph_def_; } @@ -153,31 +177,26 @@ class SimpleGraphExecutionState { return stateful_placements_; } - // Restores the map of stateful placements as a map of - // node name to placement string. - void SetStatefulPlacements(const std::unordered_map<string, string>& sp) { - mutex_lock l(mu_); - stateful_placements_ = sp; - } - private: - mutable mutex mu_; + SimpleGraphExecutionState(GraphDef* graph_def, + const SimpleGraphExecutionStateOptions& options); - Status InitBaseGraph(const BuildGraphOptions& options) - EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status InitBaseGraph(const BuildGraphOptions& options); // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these // nodes can not be moved to a different device. Maps node names to // device names. - std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_); - void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); - void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_); + std::unordered_map<string, string> stateful_placements_; // Immutable after + // ctor. + void SaveStatefulNodes(Graph* graph); + void RestoreStatefulNodes(Graph* graph); GraphDef original_graph_def_; // Immutable after ctor. const DeviceSet* device_set_; // Not owned const SessionOptions* session_options_; // Not owned + mutable mutex mu_; CostModel costs_ GUARDED_BY(mu_); // Map from name to Node for the full graph in placed_. @@ -188,7 +207,7 @@ class SimpleGraphExecutionState { std::unique_ptr<FunctionLibraryDefinition> flib_def_; // The dataflow graph owned by this object. - Graph* graph_ GUARDED_BY(mu_); + Graph* graph_; TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState); }; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 8a6e1322ac..cf9deaabd8 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -256,6 +256,8 @@ void Master::CreateSession(const CreateSessionRequest* req, const_cast<CreateSessionRequest*>(req)->mutable_graph_def(); Status create_status = session->Create(gdef); if (!create_status.ok()) { + // Takes ownership of `session` (and destroys it). + session->Close(); done(create_status); return; } diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 4a7b7c072a..29d2959426 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -180,7 +180,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { SimpleGraphExecutionState* execution_state, ProfileHandler* ph, RunStepResponse* resp); void ProcessDeviceStats(ProfileHandler* ph, - SimpleGraphExecutionState* execution_state, + const SimpleGraphExecutionState* execution_state, const DeviceStepStats& ds, bool is_rpc); string DetailText(const NodeDef& def, const NodeExecStats& ns) { @@ -724,7 +724,7 @@ void MasterSession::ReffedClientGraph::ProcessStats( } void MasterSession::ReffedClientGraph::ProcessDeviceStats( - ProfileHandler* ph, SimpleGraphExecutionState* execution_state, + ProfileHandler* ph, const SimpleGraphExecutionState* execution_state, const DeviceStepStats& ds, bool is_rpc) { const string& dev_name = ds.device(); VLOG(1) << "Device " << dev_name << " reports stats for " @@ -736,9 +736,8 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(), ns.timeline_label()); } else { - NodeDef ndef; - Status s = execution_state->GlobalNodeDefByName(ns.node_name(), &ndef); - const bool found_node_in_graph = s.ok(); + const Node* node = execution_state->get_node_by_name(ns.node_name()); + const bool found_node_in_graph = node != nullptr; if (!found_node_in_graph && ns.timeline_label().empty()) { // The counter incrementing is not thread-safe. But we don't really // care. @@ -752,12 +751,13 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( } continue; } - string optype = found_node_in_graph ? ndef.op() : ns.node_name(); + string optype = + found_node_in_graph ? node->type_string() : ns.node_name(); string details; if (!ns.timeline_label().empty()) { details = ns.timeline_label(); } else if (found_node_in_graph) { - details = DetailText(ndef, ns); + details = DetailText(node->def(), ns); } else { // Leave details string empty } @@ -892,10 +892,8 @@ Status MasterSession::Create(GraphDef* graph_def) { SimpleGraphExecutionStateOptions options; options.device_set = &devices_; options.session_options = &session_opts_; - execution_state_.reset( - new SimpleGraphExecutionState(graph_def->library(), options)); - TF_RETURN_IF_ERROR(execution_state_->Create(graph_def)); - + TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( + graph_def, options, &execution_state_)); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 7f99fd2981..e17614c819 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -50,7 +50,7 @@ class MasterSession { // Initialize the MasterSession for "def". Must be called before Extend(), // Run(), or Close(). // - // The callee may clear "def". + // After this method returns, `def` will no longer be valid. Status Create(GraphDef* def); // Returns the session handle. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 86a09551fd..e5d28f8450 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -179,10 +179,8 @@ TEST(GrpcSessionTest, NonLocalWithFilters) { { GraphDef graph_copy(graph); graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy); - TF_CHECK_OK(session->Create(graph_copy)); - auto status = session->Run({}, {}, {node_names[2]}, nullptr); + auto status = session->Create(graph_copy); EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code()); - TF_CHECK_OK(session->Close()); } } @@ -415,25 +413,23 @@ TEST(GrpcSessionTest, MultiDevices_String) { SetDevice(&def, a->name(), a_dev.name()); SetDevice(&def, b->name(), b_dev.name()); - TF_CHECK_OK(session->Create(def)); - { + Status s = session->Create(def); + if (s.ok()) { std::vector<Tensor> outputs; - Status s = session->Run({}, {b->name()}, {}, &outputs); - if (s.ok()) { - ASSERT_EQ(1, outputs.size()); - ASSERT_EQ(outputs[0].dtype(), DT_STRING); - ASSERT_EQ(outputs[0].NumElements(), 4); - for (int i = 0; i < outputs[0].NumElements(); ++i) { - EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world"); - } - } else { - LOG(ERROR) << "Error: " << s; - ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) || - (b_dev.device_type() == DEVICE_GPU)); - ASSERT_FALSE(s.ok()); + TF_CHECK_OK(session->Run({}, {b->name()}, {}, &outputs)); + ASSERT_EQ(1, outputs.size()); + ASSERT_EQ(outputs[0].dtype(), DT_STRING); + ASSERT_EQ(outputs[0].NumElements(), 4); + for (int i = 0; i < outputs[0].NumElements(); ++i) { + EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world"); } + TF_CHECK_OK(session->Close()); + } else { + LOG(ERROR) << "Error: " << s; + ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) || + (b_dev.device_type() == DEVICE_GPU)); + ASSERT_FALSE(s.ok()); } - TF_CHECK_OK(session->Close()); } } } |