aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device_factory.cc')
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_factory.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
new file mode 100644
index 0000000000..240ac47499
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -0,0 +1,52 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+
+namespace tensorflow {
+
+void RequireGPUDevice() {}
+
+class GPUDevice : public BaseGPUDevice {
+ public:
+ GPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator)
+ : BaseGPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator) {}
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ if (attr.on_host()) {
+ ProcessState* ps = ProcessState::singleton();
+ if (attr.gpu_compatible()) {
+ return ps->GetCUDAHostAllocator(0);
+ } else {
+ return cpu_allocator_;
+ }
+ } else {
+ return gpu_allocator_;
+ }
+ }
+};
+
+class GPUDeviceFactory : public BaseGPUDeviceFactory {
+ private:
+ LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) override {
+ return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator);
+ }
+};
+
+REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA