diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/context.h')
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 31 |
1 files changed, 24 insertions, 7 deletions
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 21c5bdf8e9..15eeaa8066 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #endif +#include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy { class EagerContext { public: - explicit EagerContext(const SessionOptions& opts, - ContextDevicePlacementPolicy default_policy, bool async, - std::unique_ptr<DeviceMgr> device_mgr, - Rendezvous* rendezvous); + // TODO: remove this constructor once we migrate all callers to the next one. + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<const DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous); + ~EagerContext(); // Returns the function library runtime for the given device. @@ -93,6 +101,9 @@ class EagerContext { // TODO(apassos) make this return a constant reference std::vector<Device*>* devices() { return &devices_; } + const std::vector<DeviceType>& prioritized_device_type_list() { + return prioritized_device_type_list_; + } // Clears the kernel caches. void ClearCaches(); @@ -131,6 +142,7 @@ class EagerContext { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() { return log_device_placement_; } + bool LogMemory() { return log_memory_; } Rendezvous* GetRendezvous() { return rendezvous_; } @@ -190,6 +202,7 @@ class EagerContext { // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used // instead (which in-turn use WorkerService.RecvTensor RPCs). bool UseSendTensorRPC() { return use_send_tensor_rpc_; } + bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; } private: void InitDeviceMapAndAsync(); @@ -204,11 +217,13 @@ class EagerContext { thread_local_policies_ GUARDED_BY(policy_map_mu_); // Only one of the below is set. - std::unique_ptr<DeviceMgr> local_device_manager_; - DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<const DeviceMgr> local_device_manager_; + const DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<DeviceMgr> remote_device_manager_; // Devices owned by device_manager std::vector<Device*> devices_; + std::vector<DeviceType> prioritized_device_type_list_; // All devices are not owned. gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_; Rendezvous* rendezvous_; @@ -249,11 +264,12 @@ class EagerContext { std::unordered_map<std::thread::id, bool> thread_local_async_ GUARDED_BY(async_map_mu_); + const bool log_memory_; + Env* const env_; #ifndef __ANDROID__ void CloseRemoteContexts(); - std::unique_ptr<DeviceMgr> remote_device_manager_; // The server_ is not const since we release it when the context is destroyed. // Therefore the server_ object is not marked as const (even though it should @@ -278,6 +294,7 @@ class EagerContext { #endif bool use_send_tensor_rpc_; + const bool pin_small_ops_to_cpu_; }; } // namespace tensorflow |