diff options
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index af5d5b17e7..458e133b68 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/run_handler.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool, #endif // __ANDROID__ } +static RunHandlerPool* GetOrCreateRunHandlerPool( + const SessionOptions& options) { + static RunHandlerPool* pool = + new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options)); + return pool; +} + +bool DirectSession::ShouldUseRunHandlerPool() const { + if (options_.config.session_inter_op_thread_pool_size() > 0 || + options_.config.use_per_session_threads()) { + return false; + } + return true; +} + DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory) @@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState( Status DirectSession::Create(const GraphDef& graph) { TF_RETURN_IF_ERROR(init_error_); if (graph.node_size() > 0) { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); if (graph_created_) { return errors::AlreadyExists( "A Graph has already been created for this session."); @@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) { Status DirectSession::Extend(const GraphDef& graph) { TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); return ExtendLocked(graph); } @@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } } - Executor::Args::Runner default_runner = [this, - pool](Executor::Args::Closure c) { - SchedClosure(pool, std::move(c)); - }; + std::unique_ptr<RunHandler> handler; + if (ShouldUseRunHandlerPool() && + run_options.experimental().use_run_handler_pool()) { + // Non-null only when a global inter-op pool is used. + VLOG(1) << "Using RunHandler to scheduler inter-op closures."; + handler = GetOrCreateRunHandlerPool(options_)->Get(); + } + auto* handler_ptr = handler.get(); + + Executor::Args::Runner default_runner = nullptr; + + if (pool == nullptr) { + default_runner = [](Executor::Args::Closure c) { c(); }; + } else if (handler_ptr != nullptr) { + default_runner = [handler_ptr](Executor::Args::Closure c) { + handler_ptr->ScheduleInterOpClosure(std::move(c)); + }; + } else { + default_runner = [this, pool](Executor::Args::Closure c) { + SchedClosure(pool, std::move(c)); + }; + } + for (const auto& item : executors_and_keys->items) { - // TODO(zhengxq): support partial run. - // TODO(zhengxq): if the device picks its own threadpool, we need to assign + // TODO(azaks): support partial run. + // TODO(azaks): if the device picks its own threadpool, we need to assign // less threads to the main compute pool by default. thread::ThreadPool* device_thread_pool = item.device->tensorflow_device_thread_pool(); + // TODO(crk): Investigate usage of RunHandlerPool when using device specific + // thread pool(s). if (!device_thread_pool) { args.runner = default_runner; } else { @@ -1172,7 +1209,7 @@ Status DirectSession::CreateExecutors( int graph_def_version; { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); graph_def_version = execution_state_->original_graph_def().versions().producer(); } @@ -1400,7 +1437,7 @@ Status DirectSession::CreateGraphs( std::unique_ptr<FunctionLibraryDefinition>* flib_def, RunStateArgs* run_state_args, DataTypeVector* input_types, DataTypeVector* output_types, int64* collective_graph_key) { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); std::unique_ptr<ClientGraph> client_graph; std::unique_ptr<GraphExecutionState> temp_exec_state_holder; |