diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-07-14 05:35:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-14 06:49:15 -0700 |
commit | 9c0fd2b6e8d74b8d6595cc0e9bcc323ab108be45 (patch) | |
tree | 774e6d4b2eb2b114f51c2233a9d5563ea0d707e7 | |
parent | 6b8298f0ff8a6ab43970bc1f9626e991fd6db85a (diff) |
Refactoring: Use std::unique_ptr<> to maintain pointer ownership in more places in tensorflow::DirectSession and tensorflow::SimpleClientGraph.
Remove unused func_defs in ExecutorsAndKeys.
Fixes a few memory leaks on error paths.
Change: 127425821
5 files changed, 118 insertions, 161 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 2b8badcef7..e661588c56 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -190,10 +190,10 @@ DirectSession::DirectSession(const SessionOptions& options, DirectSession::~DirectSession() { for (auto& it : partial_runs_) { - delete it.second; + it.second.reset(nullptr); } - for (auto it : executors_) { - delete it.second; + for (auto& it : executors_) { + it.second.reset(nullptr); } for (auto d : device_mgr_->ListDevices()) { d->op_segment()->RemoveHold(session_handle_); @@ -238,13 +238,9 @@ Status DirectSession::Extend(const GraphDef& graph) { Status DirectSession::ExtendLocked(const GraphDef& graph) { MaybeInitializeExecutionState(graph); - std::unique_ptr<SimpleGraphExecutionState> old_state; - SimpleGraphExecutionState* new_state = nullptr; - TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &new_state)); - - // Swap out the old state. - old_state = std::move(execution_state_); - execution_state_.reset(new_state); + std::unique_ptr<SimpleGraphExecutionState> state; + TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); + execution_state_.swap(state); graph_created_ = true; // In case this is first call return Status::OK(); @@ -330,10 +326,10 @@ Status DirectSession::Run(const RunOptions& run_options, const int64 build_cost_model = options_.config.graph_options().build_cost_model(); if (do_trace || build_cost_model > 0) { - args.stats_collector = new StepStatsCollector( + run_state.collector.reset(new StepStatsCollector( run_metadata->mutable_step_stats(), - (build_cost_model > 0) ? &cost_model_manager_ : nullptr); - run_state.collector = args.stats_collector; + (build_cost_model > 0) ? &cost_model_manager_ : nullptr)); + args.stats_collector = run_state.collector.get(); } // TODO(pbar) CostModel still gets very confused when presented @@ -414,7 +410,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); { mutex_lock l(executor_lock_); - if (!partial_runs_.insert({run_state_args.handle, run_state}).second) { + if (!partial_runs_ + .emplace(run_state_args.handle, + std::unique_ptr<RunState>(run_state)) + .second) { return errors::Internal("The handle '", run_state_args.handle, "' created for this partial run is not unique."); } @@ -445,9 +444,9 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names, } if (options_.config.graph_options().build_cost_model()) { - args.stats_collector = - new StepStatsCollector(nullptr, &cost_model_manager_); - run_state->collector = args.stats_collector; + run_state->collector.reset( + new StepStatsCollector(nullptr, &cost_model_manager_)); + args.stats_collector = run_state->collector.get(); } for (auto& item : executors_and_keys->items) { @@ -473,14 +472,14 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, return errors::InvalidArgument( "Must run 'setup' before performing partial runs!"); } - executors_and_keys = exc_it->second; + executors_and_keys = exc_it->second.get(); auto prun_it = partial_runs_.find(handle); if (prun_it == partial_runs_.end()) { return errors::InvalidArgument( "Must run 'setup' before performing partial runs!"); } - run_state = prun_it->second; + run_state = prun_it->second.get(); // Make sure that this is a new set of feeds that are still pending. for (const auto& input : inputs) { @@ -542,7 +541,6 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, if (done) { WaitForNotification(run_state, operation_timeout_in_ms_); partial_runs_.erase(handle); - delete run_state; } } return s; @@ -627,8 +625,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds, const std::vector<string>& fetches, const ExecutorsAndKeys* executors_and_keys, const RunState* run_state) { - const Graph* graph = executors_and_keys->graph; - const NameNodeMap* name_to_node = executors_and_keys->name_to_node; + const Graph* graph = executors_and_keys->graph.get(); + const NameNodeMap* name_to_node = &executors_and_keys->name_to_node; // Build the set of pending feeds that we haven't seen. std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; @@ -715,7 +713,7 @@ Status DirectSession::GetOrCreateExecutors( mutex_lock l(executor_lock_); // could use reader lock auto it = executors_.find(key); if (it != executors_.end()) { - *executors_and_keys = it->second; + *executors_and_keys = it->second.get(); return Status::OK(); } } @@ -727,15 +725,12 @@ Status DirectSession::GetOrCreateExecutors( // The executor_lock_ is intentionally released while executor is // being created. - std::unordered_map<string, Graph*> graphs; - Status s = CreateGraphs(options, &graphs, run_state_args); - TF_RETURN_IF_ERROR(s); + std::unordered_map<string, std::unique_ptr<Graph>> graphs; + TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, run_state_args)); std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); - ek->func_defs = flib_def_.get(); if (run_state_args->is_partial_run) { - ek->graph = run_state_args->graph; - ek->name_to_node = new NameNodeMap; + ek->graph = std::move(run_state_args->graph); std::unordered_set<StringPiece, StringPiece::Hasher> names; for (const string& input : inputs) { TensorId id(ParseTensorName(input)); @@ -745,9 +740,9 @@ Status DirectSession::GetOrCreateExecutors( TensorId id(ParseTensorName(output)); names.emplace(id.first); } - for (Node* n : run_state_args->graph->nodes()) { + for (Node* n : ek->graph->nodes()) { if (names.count(n->name()) > 0) { - ek->name_to_node->insert({n->name(), n}); + ek->name_to_node.insert({n->name(), n}); } } } @@ -757,23 +752,22 @@ Status DirectSession::GetOrCreateExecutors( GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { const string& partition_name = iter->first; - Graph* partition_graph = iter->second; + Graph* partition_graph = iter->second.get(); const int graph_def_version = partition_graph->versions().producer(); Device* device; - s = device_mgr_->LookupDevice(partition_name, &device); - if (!s.ok()) break; + TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device)); ek->items.resize(ek->items.size() + 1); auto* item = &(ek->items.back()); - item->flib = + item->flib.reset( NewFunctionLibraryRuntime(device_mgr_.get(), device, graph_def_version, - flib_def_.get(), optimizer_opts); + flib_def_.get(), optimizer_opts)); LocalExecutorParams params; params.device = device; - params.function_library = item->flib; - auto lib = item->flib; + params.function_library = item->flib.get(); + auto lib = item->flib.get(); auto opseg = device->op_segment(); params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { @@ -798,25 +792,19 @@ Status DirectSession::GetOrCreateExecutors( }; params.node_outputs_cb = node_outputs_callback_; + partition_graph = iter->second.release(); optimizer.Optimize(lib, device, &partition_graph); + iter->second.reset(partition_graph); - s = EnsureMemoryTypes(DeviceType(device->device_type()), device->name(), - partition_graph); - if (!s.ok()) { - break; - } - // NewLocalExecutor takes ownership of *partition_graph. - iter->second = nullptr; + TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), + device->name(), partition_graph)); + // NewLocalExecutor takes ownership of partition_graph. item->graph = partition_graph; item->executor = nullptr; - s = NewLocalExecutor(params, partition_graph, &item->executor); - if (!s.ok()) { - break; - } - } - if (!s.ok()) { - gtl::STLDeleteValues(&graphs); - return s; + Executor* executor; + TF_RETURN_IF_ERROR( + NewLocalExecutor(params, iter->second.release(), &executor)); + item->executor.reset(executor); } // Compute the rendezvous keys to avoid recomputing them every time. @@ -834,25 +822,21 @@ Status DirectSession::GetOrCreateExecutors( // Reacquire the lock, try to insert into the map. mutex_lock l(executor_lock_); - const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second; - if (!inserted) { - // Another thread created the entry before us, so delete the - // one we created and return the already created one. - auto it = executors_.find(key); - *executors_and_keys = it->second; - } else { - *executors_and_keys = ek.release(); - } + + // Another thread may have created the entry before us, in which case we will + // reuse the already created one. + auto insert_result = executors_.emplace(key, std::move(ek)); + *executors_and_keys = insert_result.first->second.get(); return Status::OK(); } -Status DirectSession::CreateGraphs(const BuildGraphOptions& options, - std::unordered_map<string, Graph*>* outputs, - RunStateArgs* run_state_args) { +Status DirectSession::CreateGraphs( + const BuildGraphOptions& options, + std::unordered_map<string, std::unique_ptr<Graph>>* outputs, + RunStateArgs* run_state_args) { mutex_lock l(graph_def_lock_); std::unique_ptr<SimpleClientGraph> client_graph; - SimpleClientGraph* cgraph = nullptr; std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder; SimpleGraphExecutionState* execution_state = nullptr; @@ -871,13 +855,13 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options, } TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend( - execution_state_->original_graph_def(), &execution_state)); - temp_exec_state_holder.reset(execution_state); + execution_state_->original_graph_def(), &temp_exec_state_holder)); + execution_state = temp_exec_state_holder.get(); } else { execution_state = execution_state_.get(); } - TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &cgraph)); + TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph)); { auto current_stateful_placements = execution_state->GetStatefulPlacements(); mutex_lock l(mu_); @@ -900,12 +884,11 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options, stateful_placements_ = execution_state->GetStatefulPlacements(); } - client_graph.reset(cgraph); // Remember the graph in run state if this is a partial run. if (run_state_args->is_partial_run) { - run_state_args->graph = new Graph(flib_def_.get()); - CopyGraph(*execution_state->full_graph(), run_state_args->graph); + run_state_args->graph.reset(new Graph(flib_def_.get())); + CopyGraph(*execution_state->full_graph(), run_state_args->graph.get()); } // Partition the graph across devices. @@ -967,25 +950,17 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options, if (!s.ok()) { break; } - Graph* device_graph = new Graph(flib_def_.get()); + std::unique_ptr<Graph> device_graph(new Graph(flib_def_.get())); GraphConstructorOptions device_opts; // There are internal operations (e.g., send/recv) that we now // allow. device_opts.allow_internal_ops = true; device_opts.expect_device_spec = true; - s = ConvertGraphDefToGraph(device_opts, *graph_def, device_graph); - if (!s.ok()) { - delete device_graph; - break; - } - outputs->insert(std::make_pair(partition_name, device_graph)); - } - if (!s.ok()) { - // Also delete other graphs created during the loop. - gtl::STLDeleteValues(outputs); - return s; + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get())); + outputs->emplace(partition_name, std::move(device_graph)); } - return Status::OK(); + return s; } ::tensorflow::Status DirectSession::Close() { @@ -993,6 +968,17 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options, return ::tensorflow::Status::OK(); } +DirectSession::RunState::RunState(const std::vector<string>& input_names, + const std::vector<string>& output_names) { + // Initially all the feeds and fetches are pending. + for (auto& name : input_names) { + pending_inputs.emplace(name); + } + for (auto& name : output_names) { + pending_outputs.emplace(name); + } +} + DirectSession::RunState::~RunState() { if (rendez != nullptr) { if (!executors_done.HasBeenNotified()) { @@ -1001,9 +987,6 @@ DirectSession::RunState::~RunState() { } rendez->Unref(); } - if (collector != nullptr) { - delete collector; - } } void DirectSession::WaitForNotification(RunState* run_state, diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 54ea33a58c..4bfd89150b 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -95,14 +95,13 @@ class DirectSession : public Session { // every partition. struct PerPartitionExecutorsAndLib { Graph* graph = nullptr; - Executor* executor = nullptr; - FunctionLibraryRuntime* flib = nullptr; + std::unique_ptr<FunctionLibraryRuntime> flib; + std::unique_ptr<Executor> executor; }; // An ExecutorsAndKeys is created for a given set of feeds/fetches. // 'step_count' is the number of times this graph is executed. - // 'func_defs' are the function definition used by all the underlying - // executors. 'graph' is the entire graph being executed. 'name_to_node' + // 'graph' is the entire graph being executed. 'name_to_node' // maps node name to node. We keep 'graph' and 'name_to_node' only in // the case of partial runs. Each item in 'items' is the executor for // a partition of the graph bundled with its dependent library runtime. @@ -110,21 +109,11 @@ class DirectSession : public Session { // are rendezvous keys for the fetches. struct ExecutorsAndKeys { int64 step_count = 0; - FunctionLibraryDefinition* func_defs = nullptr; - Graph* graph = nullptr; - NameNodeMap* name_to_node = nullptr; + std::unique_ptr<Graph> graph; + NameNodeMap name_to_node; std::vector<PerPartitionExecutorsAndLib> items; std::unordered_map<string, string> input_keys; std::unordered_map<string, string> output_keys; - - ~ExecutorsAndKeys() { - for (auto item : items) { - delete item.executor; - delete item.flib; - } - delete graph; - delete name_to_node; - } }; // For each live partial execution, the session maintains a RunState. @@ -135,22 +124,14 @@ class DirectSession : public Session { mutex mu_; Status status GUARDED_BY(mu_); IntraProcessRendezvous* rendez = nullptr; - StepStatsCollector* collector = nullptr; + std::unique_ptr<StepStatsCollector> collector; Notification executors_done; std::unordered_set<string> pending_inputs; std::unordered_set<string> pending_outputs; TensorStore tensor_store; RunState(const std::vector<string>& input_names, - const std::vector<string>& output_names) { - // Initially all the feeds and fetches are pending. - for (auto& name : input_names) { - pending_inputs.emplace(name); - } - for (auto& name : output_names) { - pending_outputs.emplace(name); - } - } + const std::vector<string>& output_names); ~RunState(); }; @@ -158,7 +139,7 @@ class DirectSession : public Session { struct RunStateArgs { bool is_partial_run = false; string handle; - Graph* graph = nullptr; + std::unique_ptr<Graph> graph; }; // Initializes the base execution state given the 'graph', @@ -175,9 +156,10 @@ class DirectSession : public Session { // Creates several graphs given the existing graph_def_ and the // input feeds and fetches, given 'devices'. - ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options, - std::unordered_map<string, Graph*>* outputs, - RunStateArgs* run_state_args); + ::tensorflow::Status CreateGraphs( + const BuildGraphOptions& options, + std::unordered_map<string, std::unique_ptr<Graph>>* outputs, + RunStateArgs* run_state_args); ::tensorflow::Status ExtendLocked(const GraphDef& graph) EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); @@ -230,11 +212,11 @@ class DirectSession : public Session { // Holds mappings from signature to the executors that process // it. The reason for a level of indirection around mapped_type is // to guarantee address stability. - std::unordered_map<string, ExecutorsAndKeys*> executors_ + std::unordered_map<string, std::unique_ptr<ExecutorsAndKeys>> executors_ GUARDED_BY(executor_lock_); // Holds mappings from handle to partial run state. - std::unordered_map<string, RunState*> partial_runs_ + std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_ GUARDED_BY(executor_lock_); // This holds all the tensors that are currently alive in the session. diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index cee26cb595..42b94f0dcd 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -66,7 +66,8 @@ Status SimpleGraphExecutionState::Create(GraphDef* graph_def) { } Status SimpleGraphExecutionState::Extend( - const GraphDef& extension_def, SimpleGraphExecutionState** out) const { + const GraphDef& extension_def, + std::unique_ptr<SimpleGraphExecutionState>* out) const { std::unordered_set<string> new_names; // 1. Build an index of the new node names. for (const NodeDef& node : extension_def.node()) { @@ -135,15 +136,11 @@ Status SimpleGraphExecutionState::Extend( combined_options.device_set = device_set_; combined_options.session_options = session_options_; - SimpleGraphExecutionState* new_execution_state = - new SimpleGraphExecutionState(ops_, combined_options); - Status new_execution_state_status = new_execution_state->Create(&gdef); - if (!new_execution_state_status.ok()) { - delete new_execution_state; - return new_execution_state_status; - } + std::unique_ptr<SimpleGraphExecutionState> new_execution_state( + new SimpleGraphExecutionState(ops_, combined_options)); + TF_RETURN_IF_ERROR(new_execution_state->Create(&gdef)); new_execution_state->SetStatefulPlacements(GetStatefulPlacements()); - *out = new_execution_state; + *out = std::move(new_execution_state); // TODO(mrry): This is likely to be used for non-throughput-sensitive // interactive workloads, but in future we may want to transfer other @@ -191,13 +188,14 @@ Status SimpleGraphExecutionState::InitBaseGraph( SimplePlacer placer(new_graph.get(), device_set_, session_options_); // TODO(mrry): Consider making the SimplePlacer cancelable. TF_RETURN_IF_ERROR(placer.Run()); + SaveStatefulNodes(new_graph.get()); graph_ = new_graph.release(); return Status::OK(); } -Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options, - SimpleClientGraph** out) { +Status SimpleGraphExecutionState::BuildGraph( + const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) { VLOG(1) << "BuildGraph"; mutex_lock l(mu_); // Lazily initialize the base graph. @@ -220,16 +218,12 @@ Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options, // Copy the extracted graph in order to make its node ids dense, // since the local CostModel used to record its stats is sized by // the largest node id. - { - std::unique_ptr<SimpleClientGraph> dense_copy(new SimpleClientGraph(ops_)); - CopyGraph(cgraph->graph, &dense_copy->graph); - cgraph = std::move(dense_copy); - } + std::unique_ptr<SimpleClientGraph> dense_copy(new SimpleClientGraph(ops_)); + CopyGraph(cgraph->graph, &dense_copy->graph); // TODO(vrv): We should check invariants of the graph here. - *out = cgraph.release(); - + *out = std::move(dense_copy); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index cb8df9fab1..adeb0995ae 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -100,13 +100,14 @@ class SimpleGraphExecutionState { // in *this, but currently does not transfer any other placement // or cost model information to the new graph. Status Extend(const GraphDef& extension_def, - SimpleGraphExecutionState** out) const; + std::unique_ptr<SimpleGraphExecutionState>* out) const; // Builds a SimpleClientGraph (a sub-graph of the full graph as induced by // the Node set specified in "options"). If successful, returns OK // and the caller takes the ownership of "*out". Otherwise, returns // an error. - Status BuildGraph(const BuildGraphOptions& options, SimpleClientGraph** out); + Status BuildGraph(const BuildGraphOptions& options, + std::unique_ptr<SimpleClientGraph>* out); // Returns OK if the named node is found in the placed full graph owned // by this execution_state, and sets *out to the NodeDef for that node. diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index dd7b9e066f..93f18e612c 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -187,9 +187,10 @@ class MasterSession : public MasterSessionInterface { class MasterSession::ReffedClientGraph : public core::RefCounted { public: ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts, - SimpleClientGraph* cg, const GraphOptions& graph_opts) + std::unique_ptr<SimpleClientGraph> cg, + const GraphOptions& graph_opts) : session_handle_(handle), - client_graph_(cg), + client_graph_(std::move(cg)), bopts_(bopts), graph_opts_(graph_opts) { VLOG(1) << "Created ReffedClientGraph for node with " @@ -204,11 +205,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { } ~ReffedClientGraph() override { - delete client_graph_; DeregisterPartitions(); } - const SimpleClientGraph* client_graph() { return client_graph_; } + const SimpleClientGraph* client_graph() { return client_graph_.get(); } // Local execution methods. @@ -233,7 +233,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { private: const string session_handle_; - SimpleClientGraph* const client_graph_ = nullptr; + const std::unique_ptr<SimpleClientGraph> client_graph_; std::unordered_set<const Node*> nodes_needing_input_mapping_; BuildGraphOptions bopts_; const GraphOptions graph_opts_; @@ -771,7 +771,7 @@ Status MasterSession::Create(GraphDef* graph_def) { Status MasterSession::Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp) { UpdateLastAccessTime(); - std::unique_ptr<SimpleGraphExecutionState> old_execution_state; + std::unique_ptr<SimpleGraphExecutionState> extended_execution_state; { mutex_lock l(mu_); // TODO(mrry): Redesign the locking with reader/writer locks to prevent @@ -790,20 +790,16 @@ Status MasterSession::Extend(const ExtendSessionRequest* req, } CHECK(execution_state_); - SimpleGraphExecutionState* extended_execution_state = nullptr; - Status s = - execution_state_->Extend(req->graph_def(), &extended_execution_state); - if (s.ok()) { - CHECK(extended_execution_state); - old_execution_state = - std::move(execution_state_); // Will be released outside the lock - execution_state_.reset(extended_execution_state); - ++graph_version_; - resp->set_new_graph_version(graph_version_); - } + TF_RETURN_IF_ERROR( + execution_state_->Extend(req->graph_def(), &extended_execution_state)); - return s; + CHECK(extended_execution_state); + // The old execution state will be released outside the lock. + execution_state_.swap(extended_execution_state); + ++graph_version_; + resp->set_new_graph_version(graph_version_); } + return Status::OK(); } Status MasterSession::StartStep(const RunStepRequest& req, @@ -824,10 +820,11 @@ Status MasterSession::StartStep(const RunStepRequest& req, // cache it. VLOG(1) << "Unseen hash " << hash << " for " << BuildGraphOptionsString(*opts); - SimpleClientGraph* client_graph = nullptr; + std::unique_ptr<SimpleClientGraph> client_graph; TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph)); - auto entry = new ReffedClientGraph(handle_, *opts, client_graph, - session_opts_.config.graph_options()); + auto entry = + new ReffedClientGraph(handle_, *opts, std::move(client_graph), + session_opts_.config.graph_options()); iter = runs_.insert({hash, entry}).first; auto obs_iter = obsolete_.find(hash); if (obs_iter != obsolete_.end()) { |