aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/context.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.cc')
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc26
1 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 70208fb6d1..5e0f0a45f8 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -17,8 +17,20 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
+namespace {
+
+bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
+ bool val;
+ if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
+ return val;
+ }
+ return default_val;
+}
+
+} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
@@ -34,8 +46,16 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
- async_default_(async) {
+ async_default_(async),
+ use_send_tensor_rpc_(false) {
InitDeviceMapAndAsync();
+ if (opts.config.inter_op_parallelism_threads() > 0) {
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(closure);
+ };
+ } else {
+ runner_ = [](std::function<void()> closure) { closure(); };
+ }
}
#ifndef __ANDROID__
@@ -59,7 +79,9 @@ EagerContext::EagerContext(
remote_device_manager_(std::move(remote_device_manager)),
server_(std::move(server)),
remote_eager_workers_(std::move(remote_eager_workers)),
- remote_contexts_(remote_contexts) {
+ remote_contexts_(remote_contexts),
+ use_send_tensor_rpc_(
+ ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) {
InitDeviceMapAndAsync();
}
#endif