about summary refs log tree commit diff homepage
path: root/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
diff options
context:
space:
mode:
author: Vijay Vasudevan <vrv@google.com> 2016-03-07 20:57:15 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-08 17:18:17 -0800
commit 52f49d7cdfde0fff2f52e069e1c588df6c3b9ee9 (patch)
tree 4d60a260c91207ac68209e70a90b86408e8848fe /tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
parent 19fd2b13725bf43c8e8236696f80b89612a5e879 (diff)
TensorFlow: create a CPU device that is only created when GPUs are also
potentially linked into the binary. This makes sure that the :core_cpu target will never have any GPU code linked in, which was confusing and weird. Change: 116618884
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device_factory.cc')
-rw-r--r-- tensorflow/core/common_runtime/gpu/gpu_device_factory.cc | 44
1 file changed, 44 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index d37a55784d..d0726f235c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
namespace tensorflow {
@@ -61,6 +62,49 @@ class GPUDeviceFactory : public BaseGPUDeviceFactory {
REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory);
+//------------------------------------------------------------------------------
+// A CPUDevice that optimizes for interaction with GPUs in the
+// process.
+// -----------------------------------------------------------------------------
+class GPUCompatibleCPUDevice : public ThreadPoolDevice {  // CPU device specialized for processes that also have GPUs
+ public:
+ GPUCompatibleCPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ Allocator* allocator)
+ : ThreadPoolDevice(options, name, memory_limit, bus_adjacency,
+ allocator) {}  // forwards all construction to the plain CPU device
+ ~GPUCompatibleCPUDevice() override {}
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {  // chooses an allocator per allocation request
+ ProcessState* ps = ProcessState::singleton();  // process-wide allocator registry
+ if (attr.gpu_compatible()) {  // caller asked for GPU-compatible host memory
+ return ps->GetCUDAHostAllocator(0);  // CUDA host allocator; presumably pinned memory for fast GPU DMA -- confirm
+ } else {
+ // Call the parent's implementation.
+ return ThreadPoolDevice::GetAllocator(attr);
+ }
+ }
+};
+
+// The associated factory.
+class GPUCompatibleCPUDeviceFactory : public DeviceFactory {  // factory producing GPUCompatibleCPUDevice instances
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ int n = 1;  // default: one CPU device when the session config is silent
+ auto iter = options.config.device_count().find("CPU");  // honor an explicit CPU device count if set
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ for (int i = 0; i < n; i++) {
+ string name = strings::StrCat(name_prefix, "/cpu:", i);  // e.g. "<prefix>/cpu:0"
+ devices->push_back(new GPUCompatibleCPUDevice(
+ options, name, Bytes(256 << 20), BUS_ANY, cpu_allocator()));  // 256 MiB memory limit per device
+ }
+ }
+};
+REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 50);  // 50 presumably a priority so this supersedes the plain CPU factory -- confirm
+
} // namespace tensorflow
#endif // GOOGLE_CUDA