diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.h')
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 864f514a19..4a180e074d 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -105,6 +105,8 @@ class EagerContext { EagerExecutor* Executor() { return &executor_; } + std::function<void(std::function<void()>)>* runner() { return &runner_; } + // Sets whether this thread should run in synchronous or asynchronous mode. Status SetAsyncForThread(bool async); @@ -180,6 +182,11 @@ class EagerContext { #ifndef __ANDROID__ Status GetClientAndContextID(Device* device, eager::EagerClient** client, uint64* context_id); + + // If true, then tensors should be shipped across processes via the + // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used + // instead (which in turn use WorkerService.RecvTensor RPCs). + bool UseSendTensorRPC() { return use_send_tensor_rpc_; } #endif private: void InitDeviceMapAndAsync(); @@ -214,6 +221,8 @@ class EagerContext { // session->devices[i]. const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + std::function<void(std::function<void()>)> runner_; + mutex cache_mu_; std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_ GUARDED_BY(cache_mu_); @@ -235,16 +244,18 @@ class EagerContext { const std::unique_ptr<DeviceMgr> remote_device_manager_; +#ifndef __ANDROID__ // The server_ is not const since we release it when the context is destroyed. // Therefore the server_ object is not marked as const (even though it should // be). -#ifndef __ANDROID__ std::unique_ptr<ServerInterface> server_; const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; const gtl::FlatMap<string, uint64> remote_contexts_; gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>> device_to_client_cache_; + + const bool use_send_tensor_rpc_; #endif }; |