diff options
 tensorflow/core/common_runtime/gpu/gpu_device.cc    |  2
 tensorflow/core/common_runtime/gpu/process_state.cc | 13
 tensorflow/core/common_runtime/gpu/process_state.h  | 19
3 files changed, 15 insertions(+), 19 deletions(-)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index fbb886e538..82ffcdcc7b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -174,6 +174,8 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
       cpu_allocator_(cpu_allocator),
       gpu_id_(gpu_id),
       sync_every_op_(sync_every_op) {
+  ProcessState::singleton()->EnableGPUDevice();
+
   gpu::StreamExecutor* executor =
       GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
   if (!executor) {
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 17995f6b29..484547f922 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -66,7 +66,7 @@ ProcessState* ProcessState::instance_ = nullptr;
   return instance_;
 }
 
-ProcessState::ProcessState() : gpu_count_(0) {
+ProcessState::ProcessState() : gpu_device_enabled_(false) {
   CHECK(instance_ == nullptr);
   instance_ = this;
 }
@@ -93,15 +93,6 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
   return MemDesc();
 }
 
-void ProcessState::SetGPUCount(int c) {
-  CHECK(gpu_count_ == 0 || gpu_count_ == c)
-      << "Cannot call SetGPUCount with a non-zero value "
-      << "not equal to prior set value.";
-  gpu_count_ = c;
-}
-
-int ProcessState::GPUCount() const { return gpu_count_; }
-
 Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes,
                                         const string& allocator_type) {
 #if GOOGLE_CUDA
@@ -187,7 +178,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
 }
 
 Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
-  if (gpu_count_ == 0 || !FLAGS_brain_mem_reg_cuda_dma) {
+  if (!HasGPUDevice() || !FLAGS_brain_mem_reg_cuda_dma) {
     return GetCPUAllocator(numa_node);
   }
   // Although we're temporarily ignoring numa_node, check for legality.
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
index 42f4c8417b..eabc281260 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.h
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -53,14 +53,17 @@ class ProcessState {
     string DebugString();
   };
 
-  // Records the number of GPUs available in the local process.
-  // It is a fatal error to call this with a value != to the value
-  // in a prior call.
-  void SetGPUCount(int c);
+  // Query whether any GPU device has been created so far.
+  // Disable thread safety analysis since a race is benign here.
+  bool HasGPUDevice() const NO_THREAD_SAFETY_ANALYSIS {
+    return gpu_device_enabled_;
+  }
 
-  // Returns number of GPUs available in local process, as set by
-  // SetGPUCount(); Returns 0 if SetGPUCount has not been called.
-  int GPUCount() const;
+  // Set the flag to indicate a GPU device has been created.
+  // Disable thread safety analysis since a race is benign here.
+  void EnableGPUDevice() NO_THREAD_SAFETY_ANALYSIS {
+    gpu_device_enabled_ = true;
+  }
 
   // Returns what we know about the memory at ptr.
   // If we know nothing, it's called CPU 0 with no other attributes.
@@ -109,9 +112,9 @@ class ProcessState {
   ProcessState();
 
   static ProcessState* instance_;
+  bool gpu_device_enabled_;
 
   mutex mu_;
-  int gpu_count_;
 
   std::vector<PoolAllocator*> cpu_allocators_ GUARDED_BY(mu_);
   std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);