diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/common_runtime/gpu/gpu_device_factory.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device_factory.cc')
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc new file mode 100644 index 0000000000..240ac47499 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc @@ -0,0 +1,52 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/common_runtime/gpu/gpu_device.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" + +namespace tensorflow { + +void RequireGPUDevice() {} + +class GPUDevice : public BaseGPUDevice { + public: + GPUDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, Allocator* gpu_allocator, + Allocator* cpu_allocator) + : BaseGPUDevice(options, name, memory_limit, bus_adjacency, gpu_id, + physical_device_desc, gpu_allocator, cpu_allocator) {} + + Allocator* GetAllocator(AllocatorAttributes attr) override { + if (attr.on_host()) { + ProcessState* ps = ProcessState::singleton(); + if (attr.gpu_compatible()) { + return ps->GetCUDAHostAllocator(0); + } else { + return cpu_allocator_; + } + } else { + return gpu_allocator_; + } + } +}; + +class GPUDeviceFactory : public BaseGPUDeviceFactory { + private: + LocalDevice* CreateGPUDevice(const SessionOptions& options, + const string& name, Bytes memory_limit, + BusAdjacency bus_adjacency, int gpu_id, + const string& physical_device_desc, + Allocator* gpu_allocator, + Allocator* cpu_allocator) override { + return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id, + physical_device_desc, gpu_allocator, cpu_allocator); + } +}; + +REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA |