diff options
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device_factory.cc')
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc index 9a000749c6..e1aaf95df6 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_device.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/threadpool_device.h" namespace tensorflow { @@ -40,9 +40,10 @@ class GPUDevice : public BaseGPUDevice { } Allocator* GetAllocator(AllocatorAttributes attr) override { + CHECK(cpu_allocator_) << "bad place 1"; if (attr.on_host()) { if (attr.gpu_compatible() || force_gpu_compatible_) { - ProcessState* ps = ProcessState::singleton(); + GPUProcessState* ps = GPUProcessState::singleton(); return ps->GetCUDAHostAllocator(0); } else { return cpu_allocator_; @@ -90,7 +91,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice { ~GPUCompatibleCPUDevice() override {} Allocator* GetAllocator(AllocatorAttributes attr) override { - ProcessState* ps = ProcessState::singleton(); + GPUProcessState* ps = GPUProcessState::singleton(); if (attr.gpu_compatible() || force_gpu_compatible_) { return ps->GetCUDAHostAllocator(0); } else { |