aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/context.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.h')
-rw-r--r--tensorflow/core/common_runtime/eager/context.h13
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 864f514a19..4a180e074d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -105,6 +105,8 @@ class EagerContext {
EagerExecutor* Executor() { return &executor_; }
+ std::function<void(std::function<void()>)>* runner() { return &runner_; }
+
// Sets whether this thread should run in synchronous or asynchronous mode.
Status SetAsyncForThread(bool async);
@@ -180,6 +182,11 @@ class EagerContext {
#ifndef __ANDROID__
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
+
+ // If true, then tensors should be shipped across processes via the
+ // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
+ // instead (which in-turn use WorkerService.RecvTensor RPCs.
+ bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
#endif
private:
void InitDeviceMapAndAsync();
@@ -214,6 +221,8 @@ class EagerContext {
// session->devices[i].
const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ std::function<void(std::function<void()>)> runner_;
+
mutex cache_mu_;
std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
GUARDED_BY(cache_mu_);
@@ -235,16 +244,18 @@ class EagerContext {
const std::unique_ptr<DeviceMgr> remote_device_manager_;
+#ifndef __ANDROID__
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
// be).
-#ifndef __ANDROID__
std::unique_ptr<ServerInterface> server_;
const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
const gtl::FlatMap<string, uint64> remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
+
+ const bool use_send_tensor_rpc_;
#endif
};