author     Smit Hinsu <hinsu@google.com>                      2018-05-21 17:42:15 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-05-21 17:44:41 -0700
commit  b1139814f91c5216eb5ff229ee7e1982e5f4e888 (patch)
tree    7f85c8229bfd47eeba49890aa75b59c8680e619c
parent  d913a243196fa07d4728c8f7c1ce6444ecd086eb (diff)
Introduce an option to allocate CUDA unified memory
PiperOrigin-RevId: 197490523
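
For context, the new options can be exercised from client code by configuring GPUOptions before creating a session. The sketch below mirrors the settings used in the new gpu_device_test.cc test; the helper name is illustrative and not part of this patch:

#include "tensorflow/core/public/session_options.h"

// Illustrative helper (not part of this change): builds SessionOptions that
// request CUDA unified memory for GPU allocations.
tensorflow::SessionOptions MakeUnifiedMemoryOptions() {
  tensorflow::SessionOptions options;
  auto* gpu_options = options.config.mutable_gpu_options();
  // A fraction above 1.0 oversubscribes device memory through unified memory
  // (requires a Pascal-class or newer GPU, per the checks in gpu_device.cc).
  gpu_options->set_per_process_gpu_memory_fraction(1.2);
  // Alternatively, request unified memory explicitly without oversubscribing.
  gpu_options->mutable_experimental()->set_use_unified_memory(true);
  return options;
}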
-rw-r--r--  tensorflow/core/BUILD | 32
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h | 20
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc | 14
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device_test.cc | 140
-rw-r--r--  tensorflow/core/protobuf/config.proto | 44
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.cc | 32
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.h | 10
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 8
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h | 9
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc | 14
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h | 10
12 files changed, 289 insertions(+), 48 deletions(-)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0d99244147..4146e26489 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -94,6 +94,7 @@ load(
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
@@ -3386,6 +3387,37 @@ tf_cc_tests_gpu(
],
)
+tf_cuda_cc_test(
+ name = "gpu_device_unified_memory_test",
+ size = "small",
+ srcs = [
+ "common_runtime/gpu/gpu_device_test.cc",
+ ],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ # Runs test on a Guitar cluster that uses P100s to test unified memory
+ # allocations.
+ tags = tf_cuda_tests_tags() + [
+ "guitar",
+ "multi_gpu",
+ ],
+ deps = [
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":gpu_id",
+ ":lib",
+ ":lib_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_cc_test_gpu(
name = "cuda_libdevice_path_test",
size = "small",
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 2f7fbbbec2..2d4c8d0201 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -31,7 +31,9 @@ GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
const string& name)
: BFCAllocator(
new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie()),
+ GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
+ gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
+ gpu_options.experimental().use_unified_memory()),
total_memory, gpu_options.allow_growth(), name) {}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index ad142e9982..a3e0d0734f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -50,8 +50,9 @@ class GPUBFCAllocator : public BFCAllocator {
class GPUMemAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit GPUMemAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
+ explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
+ bool use_unified_memory)
+ : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
CHECK(stream_exec_ != nullptr);
}
~GPUMemAllocator() override {}
@@ -59,20 +60,29 @@ class GPUMemAllocator : public SubAllocator {
void* Alloc(size_t alignment, size_t num_bytes) override {
void* ptr = nullptr;
if (num_bytes > 0) {
- ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
+ if (use_unified_memory_) {
+ ptr = stream_exec_->UnifiedMemoryAllocate(num_bytes);
+ } else {
+ ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
+ }
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
- se::DeviceMemoryBase gpu_ptr(ptr);
- stream_exec_->Deallocate(&gpu_ptr);
+ if (use_unified_memory_) {
+ stream_exec_->UnifiedMemoryDeallocate(ptr);
+ } else {
+ se::DeviceMemoryBase gpu_ptr(ptr);
+ stream_exec_->Deallocate(&gpu_ptr);
+ }
}
}
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 48d4c52bb4..cf5d11ec8b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -809,6 +809,20 @@ Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
int64 allocated_memory = 0;
const double per_process_gpu_memory_fraction =
gpu_options.per_process_gpu_memory_fraction();
+ if (per_process_gpu_memory_fraction > 1.0 ||
+ gpu_options.experimental().use_unified_memory()) {
+ int cc_major = 0, cc_minor = 0;
+ if (!se->GetDeviceDescription().cuda_compute_capability(&cc_major,
+ &cc_minor)) {
+ return errors::Internal("Failed to get compute capability for device.");
+ }
+ if (cc_major < 6) {
+ return errors::Internal(
+ "Unified memory on GPUs with compute capability lower than 6.0 "
+ "(pre-Pascal class GPUs) does not support oversubscription.");
+ }
+ }
+
if (per_process_gpu_memory_fraction == 0) {
allocated_memory = available_memory;
const int64 min_system_memory = MinSystemMemory(available_memory);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index bb00173d1e..5c6cb43eff 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -17,16 +17,45 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace {
const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
+int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+
+ int64 total_memory, available_memory;
+ CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
+ return total_memory;
+}
+
+Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
+ *cc_major = 0;
+ *cc_minor = 0;
+ return errors::Internal("Failed to get compute capability for device.");
+ }
+ return Status::OK();
+}
+
+void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
+ << s << ", expected substring " << substr;
+}
+} // namespace
+
class GPUDeviceTest : public ::testing::Test {
public:
void TearDown() override { ProcessState::singleton()->TestOnlyReset(); }
@@ -52,11 +81,6 @@ class GPUDeviceTest : public ::testing::Test {
}
return options;
}
-
- static bool StartsWith(const string& lhs, const string& rhs) {
- if (rhs.length() > lhs.length()) return false;
- return lhs.substr(0, rhs.length()) == rhs;
- }
};
TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
@@ -65,8 +89,7 @@ TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StartsWith(status.error_message(), "Could not parse entry"))
- << status;
+ ExpectErrorMessageSubstr(status, "Could not parse entry");
}
TEST_F(GPUDeviceTest, InvalidGpuId) {
@@ -75,9 +98,8 @@ TEST_F(GPUDeviceTest, InvalidGpuId) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "'visible_device_list' listed an invalid GPU id"))
- << status;
+ ExpectErrorMessageSubstr(status,
+ "'visible_device_list' listed an invalid GPU id");
}
TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
@@ -86,9 +108,8 @@ TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "visible_device_list contained a duplicate entry"))
- << status;
+ ExpectErrorMessageSubstr(status,
+ "visible_device_list contained a duplicate entry");
}
TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
@@ -97,9 +118,8 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "It's invalid to set per_process_gpu_memory_fraction"))
- << status;
+ ExpectErrorMessageSubstr(
+ status, "It's invalid to set per_process_gpu_memory_fraction");
}
TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) {
@@ -110,9 +130,8 @@ TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "Not enough GPUs to create virtual devices."))
- << status;
+ ExpectErrorMessageSubstr(status,
+ "Not enough GPUs to create virtual devices.");
}
TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) {
@@ -123,9 +142,8 @@ TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "Not enough GPUs to create virtual devices."))
- << status;
+ ExpectErrorMessageSubstr(status,
+ "Not enough GPUs to create virtual devices.");
}
TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
@@ -138,11 +156,11 @@ TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) {
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
- EXPECT_TRUE(StartsWith(status.error_message(),
- "The number of GPUs in visible_device_list doesn't "
- "match the number of elements in the virtual_devices "
- "list."))
- << status;
+ ExpectErrorMessageSubstr(
+ status,
+ "The number of GPUs in visible_device_list doesn't "
+ "match the number of elements in the virtual_devices "
+ "list.");
}
TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
@@ -153,7 +171,7 @@ TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
- for (auto d : devices) delete d;
+ gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
@@ -165,7 +183,7 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
- for (auto d : devices) delete d;
+ gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
@@ -175,7 +193,7 @@ TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
- for (auto d : devices) delete d;
+ gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
@@ -198,7 +216,67 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
devices[1]->attributes().locality().links().link(0).type());
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
devices[1]->attributes().locality().links().link(0).strength());
- for (auto d : devices) delete d;
+ gtl::STLDeleteElements(&devices);
+}
+
+// Enabling unified memory on pre-Pascal GPUs results in an initialization
+// error.
+TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
+ int cc_major, cc_minor;
+ TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+ // Exit early while running on Pascal or later GPUs.
+ if (cc_major >= 6) {
+ return;
+ }
+
+ SessionOptions opts = MakeSessionOptions("0", /*memory_fraction=*/1.2);
+ opts.config.mutable_gpu_options()
+ ->mutable_experimental()
+ ->set_use_unified_memory(true);
+ std::vector<tensorflow::Device*> devices;
+ Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
+ opts, kDeviceNamePrefix, &devices);
+ EXPECT_EQ(status.code(), error::INTERNAL);
+ ExpectErrorMessageSubstr(status, "does not support oversubscription.");
+}
+
+// Enabling unified memory on Pascal or later GPUs makes it possible to allocate
+// more memory than what is available on the device.
+TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
+ static constexpr double kGpuMemoryFraction = 1.2;
+ static constexpr CudaGpuId kCudaGpuId(0);
+
+ int cc_major, cc_minor;
+ TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+ // Exit early if running on pre-Pascal GPUs.
+ if (cc_major < 6) {
+ LOG(INFO)
+ << "Unified memory allocation is not supported with pre-Pascal GPUs.";
+ return;
+ }
+
+ SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
+ std::vector<tensorflow::Device*> devices;
+ TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
+ opts, kDeviceNamePrefix, &devices));
+ ASSERT_EQ(1, devices.size());
+
+ int64 memory_limit = devices[0]->attributes().memory_limit();
+ ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+ kGpuMemoryFraction));
+
+ AllocatorAttributes allocator_attributes = AllocatorAttributes();
+ allocator_attributes.set_gpu_compatible(true);
+ Allocator* allocator = devices[0]->GetAllocator(allocator_attributes);
+
+ // Try to allocate all the available memory after rounding down to the nearest
+ // multiple of MB.
+ void* ptr = allocator->AllocateRaw(Allocator::kAllocatorAlignment,
+ (memory_limit >> 20) << 20);
+ EXPECT_NE(ptr, nullptr);
+ allocator->DeallocateRaw(ptr);
+
+ gtl::STLDeleteElements(&devices);
}
} // namespace tensorflow
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index c1a0075b64..6cd067afcb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -14,12 +14,29 @@ import "tensorflow/core/protobuf/cluster.proto";
import "tensorflow/core/protobuf/rewriter_config.proto";
message GPUOptions {
- // A value between 0 and 1 that indicates what fraction of the
- // available GPU memory to pre-allocate for each process. 1 means
- // to pre-allocate all of the GPU memory, 0.5 means the process
- // allocates ~50% of the available GPU memory.
+ // Fraction of the available GPU memory to allocate for each process.
+ // 1 means to allocate all of the GPU memory, 0.5 means the process
+ // allocates up to ~50% of the available GPU memory.
+ //
+ // GPU memory is pre-allocated unless the allow_growth option is enabled.
+ //
+ // If greater than 1.0, uses CUDA unified memory to potentially oversubscribe
+ // the amount of memory available on the GPU device by using host memory as a
+ // swap space. Accessing memory not available on the device will be
+ // significantly slower as that would require memory transfer between the host
+ // and the device. Options to reduce the memory requirement should be
+ // considered before enabling this option as this may come with a negative
+ // performance impact. Oversubscription using the unified memory requires
+ // Pascal class or newer GPUs and it is currently only supported on the Linux
+ // operating system. See
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements
+ // for the detailed requirements.
double per_process_gpu_memory_fraction = 1;
+ // If true, the allocator does not pre-allocate the entire specified
+ // GPU memory region, instead starting small and growing as needed.
+ bool allow_growth = 4;
+
// The type of GPU allocation strategy to use.
//
// Allowed values:
@@ -35,10 +52,6 @@ message GPUOptions {
// a reasonable default (several MBs).
int64 deferred_deletion_bytes = 3;
- // If true, the allocator does not pre-allocate the entire specified
- // GPU memory region, instead starting small and growing as needed.
- bool allow_growth = 4;
-
// A comma-separated list of GPU ids that determines the 'visible'
// to 'virtual' mapping of GPU devices. For example, if TensorFlow
// can see 8 GPU devices in the process, and one wanted to map
@@ -82,9 +95,6 @@ message GPUOptions {
// the overall host system performance.
bool force_gpu_compatible = 8;
- // Everything inside Experimental is subject to change and is not subject
- // to API stability guarantees in
- // https://www.tensorflow.org/programmers_guide/version_compat.
message Experimental {
// Configuration for breaking down a visible GPU into multiple "virtual"
// devices.
@@ -124,8 +134,20 @@ message GPUOptions {
// different settings in different sessions within same process will
// result in undefined behavior.
repeated VirtualDevices virtual_devices = 1;
+
+ // If true, uses CUDA unified memory for memory allocations. If
+ // per_process_gpu_memory_fraction option is greater than 1.0, then unified
+ // memory is used regardless of the value for this field. See comments for
+ // per_process_gpu_memory_fraction field for more details and requirements
+ // of the unified memory. This option is useful to oversubscribe memory if
+ // multiple processes are sharing a single GPU while individually using less
+ // than 1.0 per process memory fraction.
+ bool use_unified_memory = 2;
}
+ // Everything inside experimental is subject to change and is not subject
+ // to API stability guarantees in
+ // https://www.tensorflow.org/programmers_guide/version_compat.
Experimental experimental = 9;
};
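
As a reading aid (not part of the patch), the decision applied in both gpu_bfc_allocator.cc and gpu_device.cc above reduces to a single predicate over these options:

#include "tensorflow/core/protobuf/config.pb.h"

// Sketch of the condition this change uses to decide whether GPU allocations
// should go through CUDA unified memory: either the memory fraction asks for
// oversubscription, or the experimental flag requests unified memory directly.
bool ShouldUseUnifiedMemory(const tensorflow::GPUOptions& gpu_options) {
  return gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
         gpu_options.experimental().use_unified_memory();
}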
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 273ed83997..09e9f9f758 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <set>
#include <utility>
+#include "cuda/include/cuda_runtime.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/lib/casts.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -924,6 +925,37 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
}
}
+/* static */ void *CUDADriver::UnifiedMemoryAllocate(CudaContext *context,
+ uint64 bytes) {
+ ScopedActivateContext activation(context);
+ CUdeviceptr result = 0;
+ // "Portable" memory is visible to all CUDA contexts. Safe for our use model.
+ CUresult res = cuMemAllocManaged(&result, bytes, CU_MEM_ATTACH_GLOBAL);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to alloc " << bytes
+ << " bytes unified memory; result: " << ToString(res);
+ return nullptr;
+ }
+ void *ptr = reinterpret_cast<void *>(result);
+ VLOG(2) << "allocated " << ptr << " for context " << context << " of "
+ << bytes << " bytes in unified memory";
+ return ptr;
+}
+
+/* static */ void CUDADriver::UnifiedMemoryDeallocate(CudaContext *context,
+ void *location) {
+ ScopedActivateContext activation(context);
+ CUdeviceptr pointer = port::bit_cast<CUdeviceptr>(location);
+ CUresult res = cuMemFree(pointer);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to free unified memory at " << location
+ << "; result: " << ToString(res);
+ } else {
+ VLOG(2) << "deallocated unified memory at " << location << " for context "
+ << context;
+ }
+}
+
/* static */ void *CUDADriver::HostAllocate(CudaContext *context,
uint64 bytes) {
ScopedActivateContext activation(context);
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h
index b952cfaf68..3713a5b7b9 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.h
+++ b/tensorflow/stream_executor/cuda/cuda_driver.h
@@ -106,6 +106,16 @@ class CUDADriver {
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
static void DeviceDeallocate(CudaContext* context, void *location);
+ // Allocates a unified memory space of size bytes associated with the given
+ // context via cuMemAllocManaged.
+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32
+ static void* UnifiedMemoryAllocate(CudaContext* context, uint64 bytes);
+
+ // Deallocates a unified memory space of size bytes associated with the given
+ // context via cuMemFree.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
+ static void UnifiedMemoryDeallocate(CudaContext* context, void* location);
+
// Allocates page-locked and CUDA-registered memory on the host via
// cuMemAllocHost.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index f686685474..773cbfb8a1 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -74,6 +74,14 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
void Deallocate(DeviceMemoryBase *mem) override;
+ void *UnifiedMemoryAllocate(uint64 size) override {
+ return CUDADriver::UnifiedMemoryAllocate(context_, size);
+ }
+
+ void UnifiedMemoryDeallocate(void *location) override {
+ return CUDADriver::UnifiedMemoryDeallocate(context_, location);
+ }
+
// CUDA allocation/registration functions are necessary because the driver
// internally sets up buffers for DMA operations (and page locks them).
// There's no external interface for us to otherwise control these DMA
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 2584c92f0c..9c989b971d 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -174,6 +174,15 @@ class StreamExecutorInterface {
virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset,
uint64 size) = 0;
virtual void Deallocate(DeviceMemoryBase *mem) = 0;
+ // Allocates unified memory space of the given size, if supported.
+ // See
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
+ // for more details on unified memory.
+ virtual void *UnifiedMemoryAllocate(uint64 size) { return nullptr; }
+
+ // Deallocates unified memory space previously allocated with
+ // UnifiedMemoryAllocate.
+ virtual void UnifiedMemoryDeallocate(void *mem) {}
virtual void *HostMemoryAllocate(uint64 size) = 0;
virtual void HostMemoryDeallocate(void *mem) = 0;
virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
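
Because the base interface gains no-op defaults (nullptr return, empty deallocate), backends other than CUDA need no changes, and a caller can detect missing support from the nullptr return. A minimal illustrative sketch follows; note that GPUMemAllocator above branches on the configuration flag rather than falling back like this, and the se:: namespace alias is assumed to be in scope as in the TensorFlow files above:

// Sketch only: falls back to ordinary device memory when the backend keeps
// the default UnifiedMemoryAllocate(), which returns nullptr.
void* AllocUnifiedOrDeviceMemory(se::StreamExecutor* stream_exec,
                                 size_t num_bytes) {
  void* ptr = stream_exec->UnifiedMemoryAllocate(num_bytes);
  if (ptr == nullptr) {
    // Unified memory unsupported by this backend; use regular device memory.
    ptr = stream_exec->AllocateArray<char>(num_bytes).opaque();
  }
  return ptr;
}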
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index eecd5bfe1f..b222a4d82a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -464,6 +464,20 @@ bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
return implementation_->GetSymbol(symbol_name, mem, bytes);
}
+void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
+ void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
+ VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
+ << ") returns " << buffer << StackTraceIfVLOG10();
+ return buffer;
+}
+
+void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
+ VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
+ << location << ")" << StackTraceIfVLOG10();
+
+ return implementation_->UnifiedMemoryDeallocate(location);
+}
+
void *StreamExecutor::HostMemoryAllocate(uint64 size) {
void *buffer = implementation_->HostMemoryAllocate(size);
VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index e426cf9931..ad80a1ba25 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -190,6 +190,16 @@ class StreamExecutor {
// activated.
void GetMemAllocs(std::map<void *, AllocRecord> *records_out);
+ // Allocates unified memory space of the given size, if supported.
+ // See
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-unified-memory-programming-hd
+ // for more details on unified memory.
+ void *UnifiedMemoryAllocate(uint64 bytes);
+
+ // Deallocates unified memory space previously allocated with
+ // UnifiedMemoryAllocate.
+ void UnifiedMemoryDeallocate(void *location);
+
// Allocates a region of host memory and registers it with the platform API.
// Memory allocated in this manner (or allocated and registered with
// HostMemoryRegister() is required for use in asynchronous memcpy operations,