aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_launch_util.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-05-09 11:22:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 11:34:00 -0700
commit7baa9ffe735adfa11c987c435216943767530269 (patch)
treeccb4615481b2d062a800b55f4b63d78e6c62505e /tensorflow/compiler/jit/xla_launch_util.cc
parenta01d9f7dfb58c72ea78ed560c78f99e96223ea76 (diff)
[XLA] Make XLA's memory allocator return an owning smart pointer.
Previously, xla::DeviceMemoryAllocator::Allocate returned a stream_executor::DeviceMemoryBase. This is morally equivalent to a raw pointer: It's on you the user to call Deallocate(). Unfortunately we ~never got this right. Essentially all users of Allocate() call it in a loop, and TF_RETURN_IF_ERROR within the loop. If any of these allocations fails (mostly commonly, due to OOM), we leak everything we've allocated up until then. This patch changes our API so that it returns an owning pointer. Now things mostly Just Work. Also worth calling out: The lambda in CpuExecutable::ExecuteOnStream passed to ExecuteComputeFunction almost certainly had multithreaded use-after-free bugs. This patch fixes them. PiperOrigin-RevId: 196000535
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index e12e88fcc9..6a0f557627 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -60,7 +60,7 @@ XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
XlaAllocator::~XlaAllocator() {}
-xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
+xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
@@ -69,13 +69,13 @@ xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
if (data == nullptr) {
return errors::ResourceExhausted("Out of memory while trying to allocate ",
size, " bytes.");
- } else {
- return se::DeviceMemoryBase(data, size);
}
+ return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
+ device_ordinal, this);
}
-Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
- wrapped_->DeallocateRaw(mem->opaque());
+Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
+ wrapped_->DeallocateRaw(mem.opaque());
return Status::OK();
}
@@ -241,7 +241,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
@@ -291,7 +291,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
- output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
*variable->tensor() = output_tensor;
}
++output_num;