diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2016-09-08 01:03:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-08 02:17:40 -0700 |
commit | 68d90864e7156c29dbf72697979bdd5d3174ac2d (patch) | |
tree | 58b81a96c8aa67305a7c2f2ed88cb4e60edd68cb | |
parent | c0169dc34a99d8541bd420ddf7b73e1e37dfbf19 (diff) |
Clean up locking in DirectSession.
Change: 132530959
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.h | 16 | ||||
-rw-r--r-- | tensorflow/core/debug/debug_gateway_test.cc | 2 |
3 files changed, 31 insertions, 41 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 4c90226231..0344bd9c97 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -222,7 +222,6 @@ DirectSession::DirectSession(const SessionOptions& options, device_mgr_(device_mgr), factory_(factory), cancellation_manager_(new CancellationManager()), - closed_(false), operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { if (options_.config.session_inter_op_thread_pool_size() > 0) { for (int i = 0; i < options_.config.session_inter_op_thread_pool_size(); @@ -808,10 +807,8 @@ Status DirectSession::GetOrCreateExecutors( run_state_args->is_partial_run); // Set the handle. - { - mutex_lock l(mu_); - run_state_args->handle = strings::StrCat(key, ";", name_counter_++); - } + run_state_args->handle = + strings::StrCat(key, ";", handle_name_counter_.fetch_add(1)); // See if we already have the executors for this run. { @@ -964,10 +961,7 @@ Status DirectSession::CreateGraphs( prune_options.session_options = &options_; temp_exec_state_holder.reset(new SimpleGraphExecutionState( execution_state_->original_graph_def().library(), prune_options)); - { - mutex_lock l(mu_); - temp_exec_state_holder->SetStatefulPlacements(stateful_placements_); - } + temp_exec_state_holder->SetStatefulPlacements(stateful_placements_); TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend( execution_state_->original_graph_def(), &temp_exec_state_holder)); @@ -977,29 +971,26 @@ Status DirectSession::CreateGraphs( } TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph)); - { - auto current_stateful_placements = execution_state->GetStatefulPlacements(); - mutex_lock l(mu_); - // Update our current state based on the execution_state's - // placements. If there are any mismatches for a node, - // we should fail, as this should never happen. - for (auto placement_pair : current_stateful_placements) { - const string& node_name = placement_pair.first; - const string& placement = placement_pair.second; - auto iter = stateful_placements_.find(node_name); - if (iter == stateful_placements_.end()) { - stateful_placements_.insert(std::make_pair(node_name, placement)); - } else if (iter->second != placement) { - return errors::Internal( - "Stateful placement mismatch. " - "Current assignment of ", - node_name, " to ", iter->second, " does not match ", placement); - } + auto current_stateful_placements = execution_state->GetStatefulPlacements(); + // Update our current state based on the execution_state's + // placements. If there are any mismatches for a node, + // we should fail, as this should never happen. + for (auto placement_pair : current_stateful_placements) { + const string& node_name = placement_pair.first; + const string& placement = placement_pair.second; + auto iter = stateful_placements_.find(node_name); + if (iter == stateful_placements_.end()) { + stateful_placements_.insert(std::make_pair(node_name, placement)); + } else if (iter->second != placement) { + return errors::Internal( + "Stateful placement mismatch. " + "Current assignment of ", + node_name, " to ", iter->second, " does not match ", placement); } - - stateful_placements_ = execution_state->GetStatefulPlacements(); } + stateful_placements_ = execution_state->GetStatefulPlacements(); + // Remember the graph in run state if this is a partial run. if (run_state_args->is_partial_run) { run_state_args->graph.reset(new Graph(flib_def_.get())); @@ -1012,8 +1003,7 @@ Status DirectSession::CreateGraphs( return node->assigned_device_name(); }; popts.new_name = [this](const string& prefix) { - mutex_lock l(mu_); - return strings::StrCat(prefix, "/_", name_counter_++); + return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1)); }; popts.get_incarnation = [](const string& name) { // The direct session does not have changing incarnation numbers. @@ -1089,7 +1079,7 @@ Status DirectSession::CreateGraphs( ::tensorflow::Status DirectSession::Close() { cancellation_manager_->StartCancel(); { - mutex_lock l(mu_); + mutex_lock l(closed_lock_); if (closed_) return ::tensorflow::Status::OK(); closed_ = true; } diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 8681d8fb7c..8e757e9273 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -211,7 +211,7 @@ class DirectSession : public Session { void WaitForNotification(RunState* run_state, int64 timeout_in_ms); ::tensorflow::Status CheckNotClosed() { - mutex_lock l(mu_); + mutex_lock l(closed_lock_); if (closed_) return errors::Cancelled("Session has been closed."); return ::tensorflow::Status::OK(); } @@ -253,14 +253,12 @@ class DirectSession : public Session { DirectSessionFactory* const factory_; // not owned CancellationManager* cancellation_manager_; - // Saves and restores device placements for stateful nodes. - mutex mu_; - // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these // nodes can not be moved to a different device. Maps node names to // device names. - std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_); + std::unordered_map<string, string> stateful_placements_ + GUARDED_BY(graph_def_lock_); // Execution_state; used when placing the entire graph. std::unique_ptr<SimpleGraphExecutionState> execution_state_ @@ -272,10 +270,12 @@ class DirectSession : public Session { std::unique_ptr<FunctionLibraryDefinition> flib_def_; // true if the Session has been Closed. - bool closed_ GUARDED_BY(mu_); + mutex closed_lock_; + bool closed_ GUARDED_BY(closed_lock_) = false; - // For generating unique names. - int64 name_counter_ GUARDED_BY(mu_) = 0; + // For generating unique names for this session instance. + std::atomic<int64> edge_name_counter_ = {0}; + std::atomic<int64> handle_name_counter_ = {0}; // For generating step ids that are unique across all sessions. static std::atomic_int_fast64_t step_id_counter_; diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc index f7897d7764..0e5a705a46 100644 --- a/tensorflow/core/debug/debug_gateway_test.cc +++ b/tensorflow/core/debug/debug_gateway_test.cc @@ -763,7 +763,7 @@ TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) { run_opts.set_output_partition_graphs(true); // This is the name of the boolean tensor fed as pred to the Switch node. // On GPU, this edge is HOST_MEMORY. - const string watched_tensor = strings::StrCat(pred_node_name_, "/_2"); + const string watched_tensor = strings::StrCat(pred_node_name_, "/_1"); const string debug_identity = "DebugIdentity"; DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts(); |