about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_device.cc 2
-rw-r--r-- tensorflow/core/common_runtime/gpu/process_state.cc 13
-rw-r--r-- tensorflow/core/common_runtime/gpu/process_state.h 19
3 files changed, 15 insertions, 19 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index fbb886e538..82ffcdcc7b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -174,6 +174,8 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id),
sync_every_op_(sync_every_op) {
+ ProcessState::singleton()->EnableGPUDevice();
+
gpu::StreamExecutor* executor =
GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
if (!executor) {
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 17995f6b29..484547f922 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -66,7 +66,7 @@ ProcessState* ProcessState::instance_ = nullptr;
return instance_;
}
-ProcessState::ProcessState() : gpu_count_(0) {
+ProcessState::ProcessState() : gpu_device_enabled_(false) {
CHECK(instance_ == nullptr);
instance_ = this;
}
@@ -93,15 +93,6 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-void ProcessState::SetGPUCount(int c) {
- CHECK(gpu_count_ == 0 || gpu_count_ == c)
- << "Cannot call SetGPUCount with a non-zero value "
- << "not equal to prior set value.";
- gpu_count_ = c;
-}
-
-int ProcessState::GPUCount() const { return gpu_count_; }
-
Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes,
const string& allocator_type) {
#if GOOGLE_CUDA
@@ -187,7 +178,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) {
}
Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
- if (gpu_count_ == 0 || !FLAGS_brain_mem_reg_cuda_dma) {
+ if (!HasGPUDevice() || !FLAGS_brain_mem_reg_cuda_dma) {
return GetCPUAllocator(numa_node);
}
// Although we're temporarily ignoring numa_node, check for legality.
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
index 42f4c8417b..eabc281260 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.h
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -53,14 +53,17 @@ class ProcessState {
string DebugString();
};
- // Records the number of GPUs available in the local process.
- // It is a fatal error to call this with a value != to the value
- // in a prior call.
- void SetGPUCount(int c);
+ // Returns true once EnableGPUDevice() has been called, i.e. after some GPU device has been created.
+ // Thread safety analysis is disabled because the author deems a race benign here. NOTE(review): the read takes no lock, which is formally a data race on a plain bool in C++ -- consider std::atomic<bool>; confirm.
+ bool HasGPUDevice() const NO_THREAD_SAFETY_ANALYSIS {
+ return gpu_device_enabled_;
+ }
- // Returns number of GPUs available in local process, as set by
- // SetGPUCount(); Returns 0 if SetGPUCount has not been called.
- int GPUCount() const;
+ // Sets the flag that records a GPU device has been created; nothing in this method ever clears it.
+ // Thread safety analysis is disabled because the author deems a race benign here. NOTE(review): the store takes no lock; concurrent callers all write the same value (true), but a concurrent HasGPUDevice() read is still formally a data race -- confirm.
+ void EnableGPUDevice() NO_THREAD_SAFETY_ANALYSIS {
+ gpu_device_enabled_ = true;
+ }
// Returns what we know about the memory at ptr.
// If we know nothing, it's called CPU 0 with no other attributes.
@@ -109,9 +112,9 @@ class ProcessState {
ProcessState();
static ProcessState* instance_;
+ bool gpu_device_enabled_;
mutex mu_;
- int gpu_count_;
std::vector<PoolAllocator*> cpu_allocators_ GUARDED_BY(mu_);
std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);