aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-14 05:35:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-14 06:49:15 -0700
commit9c0fd2b6e8d74b8d6595cc0e9bcc323ab108be45 (patch)
tree774e6d4b2eb2b114f51c2233a9d5563ea0d707e7
parent6b8298f0ff8a6ab43970bc1f9626e991fd6db85a (diff)
Refactoring: Use std::unique_ptr<> to maintain pointer ownership in more places in tensorflow::DirectSession and tensorflow::SimpleClientGraph.
Remove unused func_defs in ExecutorsAndKeys. Fixes a few memory leaks on error paths. Change: 127425821
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc159
-rw-r--r--tensorflow/core/common_runtime/direct_session.h46
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.cc30
-rw-r--r--tensorflow/core/common_runtime/simple_graph_execution_state.h5
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc39
5 files changed, 118 insertions, 161 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 2b8badcef7..e661588c56 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -190,10 +190,10 @@ DirectSession::DirectSession(const SessionOptions& options,
DirectSession::~DirectSession() {
for (auto& it : partial_runs_) {
- delete it.second;
+ it.second.reset(nullptr);
}
- for (auto it : executors_) {
- delete it.second;
+ for (auto& it : executors_) {
+ it.second.reset(nullptr);
}
for (auto d : device_mgr_->ListDevices()) {
d->op_segment()->RemoveHold(session_handle_);
@@ -238,13 +238,9 @@ Status DirectSession::Extend(const GraphDef& graph) {
Status DirectSession::ExtendLocked(const GraphDef& graph) {
MaybeInitializeExecutionState(graph);
- std::unique_ptr<SimpleGraphExecutionState> old_state;
- SimpleGraphExecutionState* new_state = nullptr;
- TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &new_state));
-
- // Swap out the old state.
- old_state = std::move(execution_state_);
- execution_state_.reset(new_state);
+ std::unique_ptr<SimpleGraphExecutionState> state;
+ TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
+ execution_state_.swap(state);
graph_created_ = true; // In case this is first call
return Status::OK();
@@ -330,10 +326,10 @@ Status DirectSession::Run(const RunOptions& run_options,
const int64 build_cost_model =
options_.config.graph_options().build_cost_model();
if (do_trace || build_cost_model > 0) {
- args.stats_collector = new StepStatsCollector(
+ run_state.collector.reset(new StepStatsCollector(
run_metadata->mutable_step_stats(),
- (build_cost_model > 0) ? &cost_model_manager_ : nullptr);
- run_state.collector = args.stats_collector;
+ (build_cost_model > 0) ? &cost_model_manager_ : nullptr));
+ args.stats_collector = run_state.collector.get();
}
// TODO(pbar) CostModel still gets very confused when presented
@@ -414,7 +410,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
{
mutex_lock l(executor_lock_);
- if (!partial_runs_.insert({run_state_args.handle, run_state}).second) {
+ if (!partial_runs_
+ .emplace(run_state_args.handle,
+ std::unique_ptr<RunState>(run_state))
+ .second) {
return errors::Internal("The handle '", run_state_args.handle,
"' created for this partial run is not unique.");
}
@@ -445,9 +444,9 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
}
if (options_.config.graph_options().build_cost_model()) {
- args.stats_collector =
- new StepStatsCollector(nullptr, &cost_model_manager_);
- run_state->collector = args.stats_collector;
+ run_state->collector.reset(
+ new StepStatsCollector(nullptr, &cost_model_manager_));
+ args.stats_collector = run_state->collector.get();
}
for (auto& item : executors_and_keys->items) {
@@ -473,14 +472,14 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
return errors::InvalidArgument(
"Must run 'setup' before performing partial runs!");
}
- executors_and_keys = exc_it->second;
+ executors_and_keys = exc_it->second.get();
auto prun_it = partial_runs_.find(handle);
if (prun_it == partial_runs_.end()) {
return errors::InvalidArgument(
"Must run 'setup' before performing partial runs!");
}
- run_state = prun_it->second;
+ run_state = prun_it->second.get();
// Make sure that this is a new set of feeds that are still pending.
for (const auto& input : inputs) {
@@ -542,7 +541,6 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
if (done) {
WaitForNotification(run_state, operation_timeout_in_ms_);
partial_runs_.erase(handle);
- delete run_state;
}
}
return s;
@@ -627,8 +625,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
const std::vector<string>& fetches,
const ExecutorsAndKeys* executors_and_keys,
const RunState* run_state) {
- const Graph* graph = executors_and_keys->graph;
- const NameNodeMap* name_to_node = executors_and_keys->name_to_node;
+ const Graph* graph = executors_and_keys->graph.get();
+ const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
// Build the set of pending feeds that we haven't seen.
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
@@ -715,7 +713,7 @@ Status DirectSession::GetOrCreateExecutors(
mutex_lock l(executor_lock_); // could use reader lock
auto it = executors_.find(key);
if (it != executors_.end()) {
- *executors_and_keys = it->second;
+ *executors_and_keys = it->second.get();
return Status::OK();
}
}
@@ -727,15 +725,12 @@ Status DirectSession::GetOrCreateExecutors(
// The executor_lock_ is intentionally released while executor is
// being created.
- std::unordered_map<string, Graph*> graphs;
- Status s = CreateGraphs(options, &graphs, run_state_args);
- TF_RETURN_IF_ERROR(s);
+ std::unordered_map<string, std::unique_ptr<Graph>> graphs;
+ TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, run_state_args));
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
- ek->func_defs = flib_def_.get();
if (run_state_args->is_partial_run) {
- ek->graph = run_state_args->graph;
- ek->name_to_node = new NameNodeMap;
+ ek->graph = std::move(run_state_args->graph);
std::unordered_set<StringPiece, StringPiece::Hasher> names;
for (const string& input : inputs) {
TensorId id(ParseTensorName(input));
@@ -745,9 +740,9 @@ Status DirectSession::GetOrCreateExecutors(
TensorId id(ParseTensorName(output));
names.emplace(id.first);
}
- for (Node* n : run_state_args->graph->nodes()) {
+ for (Node* n : ek->graph->nodes()) {
if (names.count(n->name()) > 0) {
- ek->name_to_node->insert({n->name(), n});
+ ek->name_to_node.insert({n->name(), n});
}
}
}
@@ -757,23 +752,22 @@ Status DirectSession::GetOrCreateExecutors(
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
const string& partition_name = iter->first;
- Graph* partition_graph = iter->second;
+ Graph* partition_graph = iter->second.get();
const int graph_def_version = partition_graph->versions().producer();
Device* device;
- s = device_mgr_->LookupDevice(partition_name, &device);
- if (!s.ok()) break;
+ TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
- item->flib =
+ item->flib.reset(
NewFunctionLibraryRuntime(device_mgr_.get(), device, graph_def_version,
- flib_def_.get(), optimizer_opts);
+ flib_def_.get(), optimizer_opts));
LocalExecutorParams params;
params.device = device;
- params.function_library = item->flib;
- auto lib = item->flib;
+ params.function_library = item->flib.get();
+ auto lib = item->flib.get();
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
@@ -798,25 +792,19 @@ Status DirectSession::GetOrCreateExecutors(
};
params.node_outputs_cb = node_outputs_callback_;
+ partition_graph = iter->second.release();
optimizer.Optimize(lib, device, &partition_graph);
+ iter->second.reset(partition_graph);
- s = EnsureMemoryTypes(DeviceType(device->device_type()), device->name(),
- partition_graph);
- if (!s.ok()) {
- break;
- }
- // NewLocalExecutor takes ownership of *partition_graph.
- iter->second = nullptr;
+ TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
+ device->name(), partition_graph));
+ // NewLocalExecutor takes ownership of partition_graph.
item->graph = partition_graph;
item->executor = nullptr;
- s = NewLocalExecutor(params, partition_graph, &item->executor);
- if (!s.ok()) {
- break;
- }
- }
- if (!s.ok()) {
- gtl::STLDeleteValues(&graphs);
- return s;
+ Executor* executor;
+ TF_RETURN_IF_ERROR(
+ NewLocalExecutor(params, iter->second.release(), &executor));
+ item->executor.reset(executor);
}
// Compute the rendezvous keys to avoid recomputing them every time.
@@ -834,25 +822,21 @@ Status DirectSession::GetOrCreateExecutors(
// Reacquire the lock, try to insert into the map.
mutex_lock l(executor_lock_);
- const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second;
- if (!inserted) {
- // Another thread created the entry before us, so delete the
- // one we created and return the already created one.
- auto it = executors_.find(key);
- *executors_and_keys = it->second;
- } else {
- *executors_and_keys = ek.release();
- }
+
+ // Another thread may have created the entry before us, in which case we will
+ // reuse the already created one.
+ auto insert_result = executors_.emplace(key, std::move(ek));
+ *executors_and_keys = insert_result.first->second.get();
return Status::OK();
}
-Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
- std::unordered_map<string, Graph*>* outputs,
- RunStateArgs* run_state_args) {
+Status DirectSession::CreateGraphs(
+ const BuildGraphOptions& options,
+ std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
+ RunStateArgs* run_state_args) {
mutex_lock l(graph_def_lock_);
std::unique_ptr<SimpleClientGraph> client_graph;
- SimpleClientGraph* cgraph = nullptr;
std::unique_ptr<SimpleGraphExecutionState> temp_exec_state_holder;
SimpleGraphExecutionState* execution_state = nullptr;
@@ -871,13 +855,13 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
}
TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend(
- execution_state_->original_graph_def(), &execution_state));
- temp_exec_state_holder.reset(execution_state);
+ execution_state_->original_graph_def(), &temp_exec_state_holder));
+ execution_state = temp_exec_state_holder.get();
} else {
execution_state = execution_state_.get();
}
- TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &cgraph));
+ TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph));
{
auto current_stateful_placements = execution_state->GetStatefulPlacements();
mutex_lock l(mu_);
@@ -900,12 +884,11 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
stateful_placements_ = execution_state->GetStatefulPlacements();
}
- client_graph.reset(cgraph);
// Remember the graph in run state if this is a partial run.
if (run_state_args->is_partial_run) {
- run_state_args->graph = new Graph(flib_def_.get());
- CopyGraph(*execution_state->full_graph(), run_state_args->graph);
+ run_state_args->graph.reset(new Graph(flib_def_.get()));
+ CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
}
// Partition the graph across devices.
@@ -967,25 +950,17 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
if (!s.ok()) {
break;
}
- Graph* device_graph = new Graph(flib_def_.get());
+ std::unique_ptr<Graph> device_graph(new Graph(flib_def_.get()));
GraphConstructorOptions device_opts;
// There are internal operations (e.g., send/recv) that we now
// allow.
device_opts.allow_internal_ops = true;
device_opts.expect_device_spec = true;
- s = ConvertGraphDefToGraph(device_opts, *graph_def, device_graph);
- if (!s.ok()) {
- delete device_graph;
- break;
- }
- outputs->insert(std::make_pair(partition_name, device_graph));
- }
- if (!s.ok()) {
- // Also delete other graphs created during the loop.
- gtl::STLDeleteValues(outputs);
- return s;
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get()));
+ outputs->emplace(partition_name, std::move(device_graph));
}
- return Status::OK();
+ return s;
}
::tensorflow::Status DirectSession::Close() {
@@ -993,6 +968,17 @@ Status DirectSession::CreateGraphs(const BuildGraphOptions& options,
return ::tensorflow::Status::OK();
}
+DirectSession::RunState::RunState(const std::vector<string>& input_names,
+ const std::vector<string>& output_names) {
+ // Initially all the feeds and fetches are pending.
+ for (auto& name : input_names) {
+ pending_inputs.emplace(name);
+ }
+ for (auto& name : output_names) {
+ pending_outputs.emplace(name);
+ }
+}
+
DirectSession::RunState::~RunState() {
if (rendez != nullptr) {
if (!executors_done.HasBeenNotified()) {
@@ -1001,9 +987,6 @@ DirectSession::RunState::~RunState() {
}
rendez->Unref();
}
- if (collector != nullptr) {
- delete collector;
- }
}
void DirectSession::WaitForNotification(RunState* run_state,
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 54ea33a58c..4bfd89150b 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -95,14 +95,13 @@ class DirectSession : public Session {
// every partition.
struct PerPartitionExecutorsAndLib {
Graph* graph = nullptr;
- Executor* executor = nullptr;
- FunctionLibraryRuntime* flib = nullptr;
+ std::unique_ptr<FunctionLibraryRuntime> flib;
+ std::unique_ptr<Executor> executor;
};
// An ExecutorsAndKeys is created for a given set of feeds/fetches.
// 'step_count' is the number of times this graph is executed.
- // 'func_defs' are the function definition used by all the underlying
- // executors. 'graph' is the entire graph being executed. 'name_to_node'
+ // 'graph' is the entire graph being executed. 'name_to_node'
// maps node name to node. We keep 'graph' and 'name_to_node' only in
// the case of partial runs. Each item in 'items' is the executor for
// a partition of the graph bundled with its dependent library runtime.
@@ -110,21 +109,11 @@ class DirectSession : public Session {
// are rendezvous keys for the fetches.
struct ExecutorsAndKeys {
int64 step_count = 0;
- FunctionLibraryDefinition* func_defs = nullptr;
- Graph* graph = nullptr;
- NameNodeMap* name_to_node = nullptr;
+ std::unique_ptr<Graph> graph;
+ NameNodeMap name_to_node;
std::vector<PerPartitionExecutorsAndLib> items;
std::unordered_map<string, string> input_keys;
std::unordered_map<string, string> output_keys;
-
- ~ExecutorsAndKeys() {
- for (auto item : items) {
- delete item.executor;
- delete item.flib;
- }
- delete graph;
- delete name_to_node;
- }
};
// For each live partial execution, the session maintains a RunState.
@@ -135,22 +124,14 @@ class DirectSession : public Session {
mutex mu_;
Status status GUARDED_BY(mu_);
IntraProcessRendezvous* rendez = nullptr;
- StepStatsCollector* collector = nullptr;
+ std::unique_ptr<StepStatsCollector> collector;
Notification executors_done;
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
TensorStore tensor_store;
RunState(const std::vector<string>& input_names,
- const std::vector<string>& output_names) {
- // Initially all the feeds and fetches are pending.
- for (auto& name : input_names) {
- pending_inputs.emplace(name);
- }
- for (auto& name : output_names) {
- pending_outputs.emplace(name);
- }
- }
+ const std::vector<string>& output_names);
~RunState();
};
@@ -158,7 +139,7 @@ class DirectSession : public Session {
struct RunStateArgs {
bool is_partial_run = false;
string handle;
- Graph* graph = nullptr;
+ std::unique_ptr<Graph> graph;
};
// Initializes the base execution state given the 'graph',
@@ -175,9 +156,10 @@ class DirectSession : public Session {
// Creates several graphs given the existing graph_def_ and the
// input feeds and fetches, given 'devices'.
- ::tensorflow::Status CreateGraphs(const BuildGraphOptions& options,
- std::unordered_map<string, Graph*>* outputs,
- RunStateArgs* run_state_args);
+ ::tensorflow::Status CreateGraphs(
+ const BuildGraphOptions& options,
+ std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
+ RunStateArgs* run_state_args);
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
@@ -230,11 +212,11 @@ class DirectSession : public Session {
// Holds mappings from signature to the executors that process
// it. The reason for a level of indirection around mapped_type is
// to guarantee address stability.
- std::unordered_map<string, ExecutorsAndKeys*> executors_
+ std::unordered_map<string, std::unique_ptr<ExecutorsAndKeys>> executors_
GUARDED_BY(executor_lock_);
// Holds mappings from handle to partial run state.
- std::unordered_map<string, RunState*> partial_runs_
+ std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
GUARDED_BY(executor_lock_);
// This holds all the tensors that are currently alive in the session.
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index cee26cb595..42b94f0dcd 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -66,7 +66,8 @@ Status SimpleGraphExecutionState::Create(GraphDef* graph_def) {
}
Status SimpleGraphExecutionState::Extend(
- const GraphDef& extension_def, SimpleGraphExecutionState** out) const {
+ const GraphDef& extension_def,
+ std::unique_ptr<SimpleGraphExecutionState>* out) const {
std::unordered_set<string> new_names;
// 1. Build an index of the new node names.
for (const NodeDef& node : extension_def.node()) {
@@ -135,15 +136,11 @@ Status SimpleGraphExecutionState::Extend(
combined_options.device_set = device_set_;
combined_options.session_options = session_options_;
- SimpleGraphExecutionState* new_execution_state =
- new SimpleGraphExecutionState(ops_, combined_options);
- Status new_execution_state_status = new_execution_state->Create(&gdef);
- if (!new_execution_state_status.ok()) {
- delete new_execution_state;
- return new_execution_state_status;
- }
+ std::unique_ptr<SimpleGraphExecutionState> new_execution_state(
+ new SimpleGraphExecutionState(ops_, combined_options));
+ TF_RETURN_IF_ERROR(new_execution_state->Create(&gdef));
new_execution_state->SetStatefulPlacements(GetStatefulPlacements());
- *out = new_execution_state;
+ *out = std::move(new_execution_state);
// TODO(mrry): This is likely to be used for non-throughput-sensitive
// interactive workloads, but in future we may want to transfer other
@@ -191,13 +188,14 @@ Status SimpleGraphExecutionState::InitBaseGraph(
SimplePlacer placer(new_graph.get(), device_set_, session_options_);
// TODO(mrry): Consider making the SimplePlacer cancelable.
TF_RETURN_IF_ERROR(placer.Run());
+
SaveStatefulNodes(new_graph.get());
graph_ = new_graph.release();
return Status::OK();
}
-Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
- SimpleClientGraph** out) {
+Status SimpleGraphExecutionState::BuildGraph(
+ const BuildGraphOptions& options, std::unique_ptr<SimpleClientGraph>* out) {
VLOG(1) << "BuildGraph";
mutex_lock l(mu_);
// Lazily initialize the base graph.
@@ -220,16 +218,12 @@ Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
- {
- std::unique_ptr<SimpleClientGraph> dense_copy(new SimpleClientGraph(ops_));
- CopyGraph(cgraph->graph, &dense_copy->graph);
- cgraph = std::move(dense_copy);
- }
+ std::unique_ptr<SimpleClientGraph> dense_copy(new SimpleClientGraph(ops_));
+ CopyGraph(cgraph->graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
- *out = cgraph.release();
-
+ *out = std::move(dense_copy);
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h
index cb8df9fab1..adeb0995ae 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h
@@ -100,13 +100,14 @@ class SimpleGraphExecutionState {
// in *this, but currently does not transfer any other placement
// or cost model information to the new graph.
Status Extend(const GraphDef& extension_def,
- SimpleGraphExecutionState** out) const;
+ std::unique_ptr<SimpleGraphExecutionState>* out) const;
// Builds a SimpleClientGraph (a sub-graph of the full graph as induced by
// the Node set specified in "options"). If successful, returns OK
// and the caller takes the ownership of "*out". Otherwise, returns
// an error.
- Status BuildGraph(const BuildGraphOptions& options, SimpleClientGraph** out);
+ Status BuildGraph(const BuildGraphOptions& options,
+ std::unique_ptr<SimpleClientGraph>* out);
// Returns OK if the named node is found in the placed full graph owned
// by this execution_state, and sets *out to the NodeDef for that node.
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index dd7b9e066f..93f18e612c 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -187,9 +187,10 @@ class MasterSession : public MasterSessionInterface {
class MasterSession::ReffedClientGraph : public core::RefCounted {
public:
ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
- SimpleClientGraph* cg, const GraphOptions& graph_opts)
+ std::unique_ptr<SimpleClientGraph> cg,
+ const GraphOptions& graph_opts)
: session_handle_(handle),
- client_graph_(cg),
+ client_graph_(std::move(cg)),
bopts_(bopts),
graph_opts_(graph_opts) {
VLOG(1) << "Created ReffedClientGraph for node with "
@@ -204,11 +205,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
}
~ReffedClientGraph() override {
- delete client_graph_;
DeregisterPartitions();
}
- const SimpleClientGraph* client_graph() { return client_graph_; }
+ const SimpleClientGraph* client_graph() { return client_graph_.get(); }
// Local execution methods.
@@ -233,7 +233,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
- SimpleClientGraph* const client_graph_ = nullptr;
+ const std::unique_ptr<SimpleClientGraph> client_graph_;
std::unordered_set<const Node*> nodes_needing_input_mapping_;
BuildGraphOptions bopts_;
const GraphOptions graph_opts_;
@@ -771,7 +771,7 @@ Status MasterSession::Create(GraphDef* graph_def) {
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
- std::unique_ptr<SimpleGraphExecutionState> old_execution_state;
+ std::unique_ptr<SimpleGraphExecutionState> extended_execution_state;
{
mutex_lock l(mu_);
// TODO(mrry): Redesign the locking with reader/writer locks to prevent
@@ -790,20 +790,16 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
}
CHECK(execution_state_);
- SimpleGraphExecutionState* extended_execution_state = nullptr;
- Status s =
- execution_state_->Extend(req->graph_def(), &extended_execution_state);
- if (s.ok()) {
- CHECK(extended_execution_state);
- old_execution_state =
- std::move(execution_state_); // Will be released outside the lock
- execution_state_.reset(extended_execution_state);
- ++graph_version_;
- resp->set_new_graph_version(graph_version_);
- }
+ TF_RETURN_IF_ERROR(
+ execution_state_->Extend(req->graph_def(), &extended_execution_state));
- return s;
+ CHECK(extended_execution_state);
+ // The old execution state will be released outside the lock.
+ execution_state_.swap(extended_execution_state);
+ ++graph_version_;
+ resp->set_new_graph_version(graph_version_);
}
+ return Status::OK();
}
Status MasterSession::StartStep(const RunStepRequest& req,
@@ -824,10 +820,11 @@ Status MasterSession::StartStep(const RunStepRequest& req,
// cache it.
VLOG(1) << "Unseen hash " << hash << " for "
<< BuildGraphOptionsString(*opts);
- SimpleClientGraph* client_graph = nullptr;
+ std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
- auto entry = new ReffedClientGraph(handle_, *opts, client_graph,
- session_opts_.config.graph_options());
+ auto entry =
+ new ReffedClientGraph(handle_, *opts, std::move(client_graph),
+ session_opts_.config.graph_options());
iter = runs_.insert({hash, entry}).first;
auto obs_iter = obsolete_.find(hash);
if (obs_iter != obsolete_.end()) {