diff options
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device.cc | 12 |
1 file changed, 7 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 520c2f9c34..3292ef2f62 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -36,11 +36,12 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -224,6 +225,7 @@ class BaseGPUDevice::StreamGroupFactory {
     int num_d2d_streams =
         options.experimental().num_dev_to_dev_copy_streams();
+    if (num_d2d_streams == 0) num_d2d_streams = 1;
     if (num_d2d_streams < 1 || num_d2d_streams > 4) {
       LOG(ERROR)
           << "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams="
@@ -274,7 +276,7 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
       tf_gpu_id_(tf_gpu_id),
       sync_every_op_(sync_every_op),
       max_streams_(max_streams) {
-  ProcessState::singleton()->EnableGPUDevice();
+  GPUProcessState::singleton()->EnableGPUDevice();
 }
 
 BaseGPUDevice::~BaseGPUDevice() {
@@ -856,7 +858,7 @@ void BaseGPUDevice::ReinitializeDevice(OpKernelContext* context,
       static_cast<ConcretePerOpGpuDevice*>(device);
   DCHECK(concrete_device);
   const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
-      streams_[stream_id]->compute->implementation()->CudaStreamMemberHack());
+      streams_[stream_id]->compute->implementation()->GpuStreamMemberHack());
   concrete_device->Reinitialize(context, cuda_stream, tf_gpu_id_, allocator,
                                 scratch_[stream_id]);
 }
@@ -1072,7 +1074,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
   se::StreamExecutor* se =
       GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
   const se::DeviceDescription& desc = se->GetDeviceDescription();
-  ProcessState* process_state = ProcessState::singleton();
+  GPUProcessState* process_state = GPUProcessState::singleton();
   Allocator* gpu_allocator = process_state->GetGPUAllocator(
       options.config.gpu_options(), tf_gpu_id, memory_limit);
   if (gpu_allocator == nullptr) {
@@ -1092,7 +1094,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
   BaseGPUDevice* gpu_device = CreateGPUDevice(
       options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
       tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
-      process_state->GetCPUAllocator(numa_node));
+      ProcessState::singleton()->GetCPUAllocator(numa_node));
   LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
             << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
             << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";