aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Suharsh Sivakumar <suharshs@google.com> 2016-09-08 01:03:07 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-09-08 02:17:40 -0700
commit 68d90864e7156c29dbf72697979bdd5d3174ac2d (patch)
tree 58b81a96c8aa67305a7c2f2ed88cb4e60edd68cb
parent c0169dc34a99d8541bd420ddf7b73e1e37dfbf19 (diff)
Clean up locking in DirectSession.
Change: 132530959
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc 54
-rw-r--r-- tensorflow/core/common_runtime/direct_session.h 16
-rw-r--r-- tensorflow/core/debug/debug_gateway_test.cc 2
3 files changed, 31 insertions, 41 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 4c90226231..0344bd9c97 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -222,7 +222,6 @@ DirectSession::DirectSession(const SessionOptions& options,
device_mgr_(device_mgr),
factory_(factory),
cancellation_manager_(new CancellationManager()),
- closed_(false),
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
if (options_.config.session_inter_op_thread_pool_size() > 0) {
for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
@@ -808,10 +807,8 @@ Status DirectSession::GetOrCreateExecutors(
run_state_args->is_partial_run);
// Set the handle.
- {
- mutex_lock l(mu_);
- run_state_args->handle = strings::StrCat(key, ";", name_counter_++);
- }
+ run_state_args->handle =
+ strings::StrCat(key, ";", handle_name_counter_.fetch_add(1));
// See if we already have the executors for this run.
{
@@ -964,10 +961,7 @@ Status DirectSession::CreateGraphs(
prune_options.session_options = &options_;
temp_exec_state_holder.reset(new SimpleGraphExecutionState(
execution_state_->original_graph_def().library(), prune_options));
- {
- mutex_lock l(mu_);
- temp_exec_state_holder->SetStatefulPlacements(stateful_placements_);
- }
+ temp_exec_state_holder->SetStatefulPlacements(stateful_placements_);
TF_RETURN_IF_ERROR(temp_exec_state_holder->Extend(
execution_state_->original_graph_def(), &temp_exec_state_holder));
@@ -977,29 +971,26 @@ Status DirectSession::CreateGraphs(
}
TF_RETURN_IF_ERROR(execution_state->BuildGraph(options, &client_graph));
- {
- auto current_stateful_placements = execution_state->GetStatefulPlacements();
- mutex_lock l(mu_);
- // Update our current state based on the execution_state's
- // placements. If there are any mismatches for a node,
- // we should fail, as this should never happen.
- for (auto placement_pair : current_stateful_placements) {
- const string& node_name = placement_pair.first;
- const string& placement = placement_pair.second;
- auto iter = stateful_placements_.find(node_name);
- if (iter == stateful_placements_.end()) {
- stateful_placements_.insert(std::make_pair(node_name, placement));
- } else if (iter->second != placement) {
- return errors::Internal(
- "Stateful placement mismatch. "
- "Current assignment of ",
- node_name, " to ", iter->second, " does not match ", placement);
- }
+ auto current_stateful_placements = execution_state->GetStatefulPlacements();
+ // Update our current state based on the execution_state's
+ // placements. If there are any mismatches for a node,
+ // we should fail, as this should never happen.
+ for (auto placement_pair : current_stateful_placements) {
+ const string& node_name = placement_pair.first;
+ const string& placement = placement_pair.second;
+ auto iter = stateful_placements_.find(node_name);
+ if (iter == stateful_placements_.end()) {
+ stateful_placements_.insert(std::make_pair(node_name, placement));
+ } else if (iter->second != placement) {
+ return errors::Internal(
+ "Stateful placement mismatch. "
+ "Current assignment of ",
+ node_name, " to ", iter->second, " does not match ", placement);
}
-
- stateful_placements_ = execution_state->GetStatefulPlacements();
}
+ stateful_placements_ = execution_state->GetStatefulPlacements();
+
// Remember the graph in run state if this is a partial run.
if (run_state_args->is_partial_run) {
run_state_args->graph.reset(new Graph(flib_def_.get()));
@@ -1012,8 +1003,7 @@ Status DirectSession::CreateGraphs(
return node->assigned_device_name();
};
popts.new_name = [this](const string& prefix) {
- mutex_lock l(mu_);
- return strings::StrCat(prefix, "/_", name_counter_++);
+ return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
};
popts.get_incarnation = [](const string& name) {
// The direct session does not have changing incarnation numbers.
@@ -1089,7 +1079,7 @@ Status DirectSession::CreateGraphs(
::tensorflow::Status DirectSession::Close() {
cancellation_manager_->StartCancel();
{
- mutex_lock l(mu_);
+ mutex_lock l(closed_lock_);
if (closed_) return ::tensorflow::Status::OK();
closed_ = true;
}
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 8681d8fb7c..8e757e9273 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -211,7 +211,7 @@ class DirectSession : public Session {
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
::tensorflow::Status CheckNotClosed() {
- mutex_lock l(mu_);
+ mutex_lock l(closed_lock_);
if (closed_) return errors::Cancelled("Session has been closed.");
return ::tensorflow::Status::OK();
}
@@ -253,14 +253,12 @@ class DirectSession : public Session {
DirectSessionFactory* const factory_; // not owned
CancellationManager* cancellation_manager_;
- // Saves and restores device placements for stateful nodes.
- mutex mu_;
-
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
// device names.
- std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
+ std::unordered_map<string, string> stateful_placements_
+ GUARDED_BY(graph_def_lock_);
// Execution_state; used when placing the entire graph.
std::unique_ptr<SimpleGraphExecutionState> execution_state_
@@ -272,10 +270,12 @@ class DirectSession : public Session {
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
// true if the Session has been Closed.
- bool closed_ GUARDED_BY(mu_);
+ mutex closed_lock_;
+ bool closed_ GUARDED_BY(closed_lock_) = false;
- // For generating unique names.
- int64 name_counter_ GUARDED_BY(mu_) = 0;
+ // For generating unique names for this session instance.
+ std::atomic<int64> edge_name_counter_ = {0};
+ std::atomic<int64> handle_name_counter_ = {0};
// For generating step ids that are unique across all sessions.
static std::atomic_int_fast64_t step_id_counter_;
diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc
index f7897d7764..0e5a705a46 100644
--- a/tensorflow/core/debug/debug_gateway_test.cc
+++ b/tensorflow/core/debug/debug_gateway_test.cc
@@ -763,7 +763,7 @@ TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
run_opts.set_output_partition_graphs(true);
// This is the name of the boolean tensor fed as pred to the Switch node.
// On GPU, this edge is HOST_MEMORY.
- const string watched_tensor = strings::StrCat(pred_node_name_, "/_2");
+ const string watched_tensor = strings::StrCat(pred_node_name_, "/_1");
const string debug_identity = "DebugIdentity";
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();