diff options
author | 2016-03-05 20:57:17 -0800 | |
---|---|---|
committer | 2016-03-05 21:01:35 -0800 | |
commit | 4c149223e3bcf18b0c30b876dfed443f75593387 (patch) | |
tree | 820ee5af86108b4e59f6543217c480729207d63b | |
parent | 5fff4f796149d0557c2ee69a333f4aaecdfbb05e (diff) |
TensorFlow: enable CUDA host memory allocation for GPU-compatible
buffers when copying to the CPU device.
Rearranges some of the internal GPU libraries so that they are split into
library-specific and runtime-specific targets.
Change: 116472314
-rw-r--r-- | tensorflow/core/BUILD | 62 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/threadpool_device.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc/BUILD | 1 |
3 files changed, 56 insertions, 15 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7d01035d0d..e5095667a5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -964,6 +964,7 @@ tf_cuda_library( ), copts = tf_copts(), deps = [ + ":core_gpu_internal", ":framework", ":framework_internal", ":function_ops_op_lib", @@ -972,12 +973,46 @@ tf_cuda_library( ":lib", ":lib_internal", ":protos_all_cc", + ":stream_executor", "//tensorflow/core/kernels:required", "//third_party/eigen3", ], alwayslink = 1, ) +# This target should not link in any GPU runtime (CUDA) dependencies, +# only libraries for interfacing with GPUs that can be safely linked +# into CPU binaries. +cc_library( + name = "core_gpu_internal", + srcs = [ + "common_runtime/gpu/gpu_allocator_retry.cc", + "common_runtime/gpu/gpu_bfc_allocator.cc", + "common_runtime/gpu/gpu_debug_allocator.cc", + "common_runtime/gpu/gpu_init.cc", + "common_runtime/gpu/pool_allocator.cc", + "common_runtime/gpu/process_state.cc", + ], + hdrs = [ + "common_runtime/gpu/gpu_allocator_retry.h", + "common_runtime/gpu/gpu_bfc_allocator.h", + "common_runtime/gpu/gpu_debug_allocator.h", + "common_runtime/gpu/gpu_init.h", + "common_runtime/gpu/pool_allocator.h", + "common_runtime/gpu/process_state.h", + "common_runtime/gpu/visitable_allocator.h", + ], + copts = tf_copts(), + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":stream_executor", + ], +) + cc_library( name = "cuda", deps = [ @@ -1006,20 +1041,18 @@ tf_cuda_library( tf_cuda_library( name = "gpu_runtime", - srcs = glob( - [ - "common_runtime/gpu/*.cc", - ], - exclude = [ - "**/*main.cc", - "**/*test.cc", - "common_runtime/gpu/gpu_event_mgr.cc", - ], - ), - hdrs = glob( - ["common_runtime/gpu/*.h"], - exclude = ["common_runtime/gpu/gpu_event_mgr.h"], - ), + srcs = [ + "common_runtime/gpu/gpu_device.cc", + "common_runtime/gpu/gpu_device_factory.cc", + "common_runtime/gpu/gpu_stream_util.cc", + "common_runtime/gpu/gpu_util.cc", + 
"common_runtime/gpu/gpu_util_platform_specific.cc", + ], + hdrs = [ + "common_runtime/gpu/gpu_device.h", + "common_runtime/gpu/gpu_stream_util.h", + "common_runtime/gpu/gpu_util.h", + ], copts = tf_copts(), cuda_deps = [ ":cuda", @@ -1028,6 +1061,7 @@ tf_cuda_library( deps = [ ":core_cpu", ":core_cpu_internal", + ":core_gpu_internal", ":framework", ":framework_internal", ":gpu_lib", diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 6477e9a336..f2eb4d3e5f 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" @@ -52,7 +53,12 @@ void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { } Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) { - return allocator_; + ProcessState* ps = ProcessState::singleton(); + if (attr.gpu_compatible()) { + return ps->GetCUDAHostAllocator(0); + } else { + return allocator_; + } } Status ThreadPoolDevice::MakeTensorFromProto( diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index d9d016e67e..4d85b0f734 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -138,6 +138,7 @@ cc_library( ":grpc_call", ":grpc_util", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_gpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_lib", "//tensorflow/core:gpu_runtime", |