diff options
author | 2016-10-10 01:44:17 -0800 | |
---|---|---|
committer | 2016-10-10 02:54:26 -0700 | |
commit | 24dedaf72ae86d3daf173cb539a18ea5ed09dd5c (patch) | |
tree | 15d15ef909a67bdb7d42b78bd9598f19ba36fe38 /tensorflow/core/common_runtime/direct_session.cc | |
parent | 81fab8a90f4e6d38e667bfec45b404030b5c0601 (diff) |
Simplify locking in SimpleGraphExecutionState.
This change removes lazy initialization of the base graph in
SimpleGraphExecutionState, and instead initializes the base graph when
the object is constructed. This allows us to access the graph (which is
read-only once written) without acquiring a mutex, which improves
performance for concurrent read operations on the graph.
In particular, this change optimizes the statistics processing code in
`master_session.cc`, which resolves node types and other details using
the Graph* in a SimpleGraphExecutionState. When many steps
concurrently compute timelines (as might happen during a profiling
session), there can be massive contention on this shared read-only
structure, leading to inaccurate timing results. (In addition to
relieving the contention, this change avoids copying each `NodeDef` to
generate its profiling information.)
To simplify the locking, we need to combine the existing constructor for
SimpleGraphExecutionState, which cannot fail, with some methods that can
fail due to invalid input. Therefore the construction of a
SimpleGraphExecutionState now uses static factory methods, so that
construction can return an error status.
**N.B.** If you use the C++ Session API, you may notice that many
errors are now raised earlier (i.e. on Session::Create or
Session::Extend) rather than on the first Session::Run call.
Change: 135654905
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 71 |
1 files changed, 47 insertions, 24 deletions
```diff
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 52a810d76b..848a71c474 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -285,28 +285,47 @@ DirectSession::~DirectSession() {
   flib_def_.reset(nullptr);
 }
 
-void DirectSession::MaybeInitializeExecutionState(const GraphDef& graph) {
+Status DirectSession::MaybeInitializeExecutionState(
+    const GraphDef& graph, bool* out_already_initialized) {
   // If already initialized, do nothing.
   if (flib_def_ && execution_state_) {
-    return;
+    *out_already_initialized = true;
+    return Status::OK();
   }
   // Set up the per-session execution state.
+  // NOTE(mrry): The function library created here will be used for
+  // all subsequent extensions of the graph.
   flib_def_.reset(
       new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
   SimpleGraphExecutionStateOptions options;
   options.device_set = &device_set_;
   options.session_options = &options_;
-  execution_state_.reset(
-      new SimpleGraphExecutionState(graph.library(), options));
+  // TODO(mrry,suharshs): We explicitly copy `graph` so that
+  // `MakeForBaseGraph()` can take ownership of its
+  // contents. Previously this happened implicitly in calls to the
+  // `SimpleGraphExecutionState`. Other sessions call
+  // `MakeForBaseGraph` in such a way that we can destructively read
+  // the passed-in `GraphDef`. In principle we could do the same here,
+  // with a wider refactoring; we might revise the direct session so
+  // that it copies the graph fewer times.
+  GraphDef temp(graph);
+  TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
+      &temp, options, &execution_state_));
+  graph_created_ = true;
+  *out_already_initialized = false;
+  return Status::OK();
 }
 
 Status DirectSession::Create(const GraphDef& graph) {
-  mutex_lock l(graph_def_lock_);
-  if (graph_created_) {
-    return errors::AlreadyExists(
-        "A Graph has already been created for this session.");
+  if (graph.node_size() > 0) {
+    mutex_lock l(graph_def_lock_);
+    if (graph_created_) {
+      return errors::AlreadyExists(
+          "A Graph has already been created for this session.");
+    }
+    return ExtendLocked(graph);
   }
-  return ExtendLocked(graph);
+  return Status::OK();
 }
 
 Status DirectSession::Extend(const GraphDef& graph) {
@@ -316,12 +335,16 @@ Status DirectSession::Extend(const GraphDef& graph) {
 }
 
 Status DirectSession::ExtendLocked(const GraphDef& graph) {
-  MaybeInitializeExecutionState(graph);
-  std::unique_ptr<SimpleGraphExecutionState> state;
-  TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
-  execution_state_.swap(state);
-
-  graph_created_ = true; // In case this is first call
+  bool already_initialized;
+  // If this is the first call, we can initialize the execution state
+  // with `graph` and do not need to call `Extend()`.
+  TF_RETURN_IF_ERROR(
+      MaybeInitializeExecutionState(graph, &already_initialized));
+  if (already_initialized) {
+    std::unique_ptr<SimpleGraphExecutionState> state;
+    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
+    execution_state_.swap(state);
+  }
   return Status::OK();
 }
 
@@ -949,7 +972,7 @@ Status DirectSession::GetOrCreateExecutors(
 }
 
 Status DirectSession::CreateGraphs(
-    const BuildGraphOptions& options,
+    const BuildGraphOptions& subgraph_options,
     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
     RunStateArgs* run_state_args) {
@@ -960,23 +983,23 @@ Status DirectSession::CreateGraphs(
   SimpleGraphExecutionState* execution_state = nullptr;
   if (options_.config.graph_options().place_pruned_graph()) {
     // Because we are placing pruned graphs, we need to create a
-    // new SimpleGraphExecutorState for every new unseen graph,
+    // new SimpleGraphExecutionState for every new unseen graph,
     // and then place it.
     SimpleGraphExecutionStateOptions prune_options;
     prune_options.device_set = &device_set_;
     prune_options.session_options = &options_;
-    temp_exec_state_holder.reset(new SimpleGraphExecutionState(
-        execution_state_->original_graph_def().library(), prune_options));
-    temp_exec_state_holder->SetStatefulPlacements(stateful_placements_);
-
-    TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend(
-        execution_state_->original_graph_def(), &temp_exec_state_holder));
+    prune_options.stateful_placements = stateful_placements_;
+    TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForPrunedGraph(
+        execution_state_->original_graph_def().library(), prune_options,
+        execution_state_->original_graph_def(), subgraph_options,
+        &temp_exec_state_holder, &client_graph));
     execution_state = temp_exec_state_holder.get();
   } else {
     execution_state = execution_state_.get();
+    TF_RETURN_IF_ERROR(
+        execution_state->BuildGraph(subgraph_options, &client_graph));
   }
 
-  TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph));
   auto current_stateful_placements = execution_state->GetStatefulPlacements();
   // Update our current state based on the execution_state's
   // placements. If there are any mismatches for a node,
```