aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/direct_session.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc57
1 files changed, 47 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index af5d5b17e7..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState(
Status DirectSession::Create(const GraphDef& graph) {
TF_RETURN_IF_ERROR(init_error_);
if (graph.node_size() > 0) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (graph_created_) {
return errors::AlreadyExists(
"A Graph has already been created for this session.");
@@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) {
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
return ExtendLocked(graph);
}
@@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
@@ -1172,7 +1209,7 @@ Status DirectSession::CreateExecutors(
int graph_def_version;
{
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
graph_def_version =
execution_state_->original_graph_def().versions().producer();
}
@@ -1400,7 +1437,7 @@ Status DirectSession::CreateGraphs(
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
std::unique_ptr<ClientGraph> client_graph;
std::unique_ptr<GraphExecutionState> temp_exec_state_holder;