diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.cc')
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 70208fb6d1..5e0f0a45f8 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -17,8 +17,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { +namespace { + +bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { + bool val; + if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) { + return val; + } + return default_val; +} + +} // namespace EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, @@ -34,8 +46,16 @@ EagerContext::EagerContext(const SessionOptions& opts, local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), - async_default_(async) { + async_default_(async), + use_send_tensor_rpc_(false) { InitDeviceMapAndAsync(); + if (opts.config.inter_op_parallelism_threads() > 0) { + runner_ = [this](std::function<void()> closure) { + this->thread_pool_->Schedule(closure); + }; + } else { + runner_ = [](std::function<void()> closure) { closure(); }; + } } #ifndef __ANDROID__ @@ -59,7 +79,9 @@ EagerContext::EagerContext( remote_device_manager_(std::move(remote_device_manager)), server_(std::move(server)), remote_eager_workers_(std::move(remote_eager_workers)), - remote_contexts_(remote_contexts) { + remote_contexts_(remote_contexts), + use_send_tensor_rpc_( + ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) { InitDeviceMapAndAsync(); } #endif |