diff options
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device_factory.cc')
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc index 94143a55d5..d9fa5a6b96 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -31,12 +31,16 @@ class GPUDevice : public BaseGPUDevice { Allocator* cpu_allocator) : BaseGPUDevice(options, name, memory_limit, locality, gpu_id, physical_device_desc, gpu_allocator, cpu_allocator, - false /* sync every op */, 1 /* max_streams */) {} + false /* sync every op */, 1 /* max_streams */) { + if (options.config.has_gpu_options()) { + force_gpu_compatible_ = options.config.gpu_options.force_gpu_compatible; + } + } Allocator* GetAllocator(AllocatorAttributes attr) override { if (attr.on_host()) { ProcessState* ps = ProcessState::singleton(); - if (attr.gpu_compatible()) { + if (attr.gpu_compatible() || force_gpu_compatible_) { return ps->GetCUDAHostAllocator(0); } else { return cpu_allocator_; @@ -71,12 +75,16 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice { GPUCompatibleCPUDevice(const SessionOptions& options, const string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator) - : ThreadPoolDevice(options, name, memory_limit, locality, allocator) {} + : ThreadPoolDevice(options, name, memory_limit, locality, allocator) { + if (options.config.has_gpu_options()) { + force_gpu_compatible_ = options.config.gpu_options.force_gpu_compatible; + } + } ~GPUCompatibleCPUDevice() override {} Allocator* GetAllocator(AllocatorAttributes attr) override { ProcessState* ps = ProcessState::singleton(); - if (attr.gpu_compatible()) { + if (attr.gpu_compatible() || force_gpu_compatible_) { return ps->GetCUDAHostAllocator(0); } else { // Call the parent's implementation. |