Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc  46  +++++++++++++++++++++++++++++++++++++++-------
1 file changed, 39 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 5ceccc769f..6134b8c694 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -64,11 +64,13 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
- void* data =
- wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
- if (data == nullptr) {
- return errors::ResourceExhausted("Out of memory while trying to allocate ",
- size, " bytes.");
+ void* data = nullptr;
+ if (size != 0) {
+ data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
+ if (data == nullptr) {
+ return errors::ResourceExhausted(
+ "Out of memory while trying to allocate ", size, " bytes.");
+ }
}
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
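
The zero-size guard above changes the contract of XlaAllocator::Allocate: a zero-byte request now returns a null but valid DeviceMemoryBase instead of calling the wrapped allocator, whose nullptr return would previously have been misreported as out-of-memory. A minimal standalone sketch of the same pattern, using hypothetical RawAllocator/DeviceMemoryBase stand-ins rather than the real TensorFlow types:

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <utility>

// Illustrative stand-ins for the TF/XLA types; not the real API.
struct DeviceMemoryBase {
  void* opaque = nullptr;
  uint64_t size = 0;
};
struct RawAllocator {
  void* AllocateRaw(uint64_t size) { return std::malloc(size); }
};

// Mirrors the patched control flow: a zero-byte request short-circuits
// to a null handle; only a failed nonzero request is a genuine OOM.
std::pair<bool, DeviceMemoryBase> Allocate(RawAllocator* wrapped,
                                           uint64_t size) {
  void* data = nullptr;
  if (size != 0) {
    data = wrapped->AllocateRaw(size);
    if (data == nullptr) return {false, {}};  // out of memory
  }
  return {true, DeviceMemoryBase{data, size}};
}

int main() {
  RawAllocator alloc;
  auto zero = Allocate(&alloc, 0);
  assert(zero.first && zero.second.opaque == nullptr);  // success, not OOM
  auto some = Allocate(&alloc, 64);
  assert(some.first && some.second.opaque != nullptr);
  std::free(some.second.opaque);
}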
@@ -115,14 +117,22 @@ using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors)
+ bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
- allocate_xla_tensors_(allocate_xla_tensors) {}
+ allocate_xla_tensors_(allocate_xla_tensors),
+ use_multiple_streams_(use_multiple_streams) {
+ if (use_multiple_streams_) {
+ CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+ "be allocating XLA tensors!";
+ }
+}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
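
The constructor change above couples the two modes: multi-stream execution depends on the definition events stored on XlaTensor (used in the hunks below), so it is only sound when this context allocates XLA tensors itself, and the CHECK makes a bad flag combination fail fast at construction. A minimal sketch of that fail-fast invariant, with hypothetical names:

#include <cassert>

// Illustrative mirror of the flag dependency; not the real constructor.
class LaunchContext {
 public:
  LaunchContext(bool allocate_xla_tensors, bool use_multiple_streams)
      : allocate_xla_tensors_(allocate_xla_tensors),
        use_multiple_streams_(use_multiple_streams) {
    // Definition events live on XlaTensor, so multi-stream mode
    // requires that this context allocate XLA tensors.
    if (use_multiple_streams_) assert(allocate_xla_tensors_);
  }

 private:
  bool allocate_xla_tensors_;
  bool use_multiple_streams_;
};

int main() {
  LaunchContext context(/*allocate_xla_tensors=*/true,
                        /*use_multiple_streams=*/true);
  (void)context;
}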
@@ -140,6 +150,16 @@ void XlaComputationLaunchContext::PopulateInputs(
t = &(ctx->input(arg_num));
}
+ if (use_multiple_streams_) {
+ CHECK(stream) << "Must have a stream available when using XLA tensors!";
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+ CHECK(xla_tensor);
+ if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+ stream->ThenWaitFor(event);
+ xla_tensor->SetDefinedOn(stream);
+ }
+ }
+
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -248,6 +268,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
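
This PopulateOutputs hunk (and the parallel variable-output hunk below) adds the producer half of the protocol: right after the output buffer is attached, record a fresh event on the compute stream and hand it to the XlaTensor, so consumers on other streams have something to wait on. A minimal sketch of that record-and-attach step, again with illustrative stand-ins rather than the StreamExecutor API:

#include <cassert>
#include <utility>

struct StreamExecutorLike {};

struct Event {
  explicit Event(StreamExecutorLike*) {}
  bool Init() { initialized_ = true; return true; }
  bool initialized_ = false;
};

struct Stream {
  StreamExecutorLike* parent() { return &executor_; }
  void ThenRecordEvent(Event*) {}  // would enqueue the record on device
  StreamExecutorLike executor_;
};

struct XlaTensorLike {
  void SetDefinedOn(Stream*, Event e) { event_ = std::move(e); }
  Event event_{nullptr};
};

// Mirrors the pattern added after set_shaped_buffer(): create the
// event, initialize it (CHECK-ed in the patch), record it on the
// stream that fills the buffer, then attach it to the tensor.
void MarkOutputDefined(Stream* stream, XlaTensorLike* tensor) {
  Event event(stream->parent());
  const bool ok = event.Init();
  assert(ok);
  (void)ok;
  stream->ThenRecordEvent(&event);
  tensor->SetDefinedOn(stream, std::move(event));
}

int main() {
  Stream stream;
  XlaTensorLike tensor;
  MarkOutputDefined(&stream, &tensor);
  assert(tensor.event_.initialized_);
}

The event is recorded on the stream that produced the output, which is exactly what the consumer-side GetDefinitionEvent check in PopulateInputs compares against.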
@@ -302,6 +328,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(