path: root/tensorflow/compiler/jit/xla_launch_util.cc
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc'):
 tensorflow/compiler/jit/xla_launch_util.cc | 60 ++++++++++++++++++++++++-----
 1 file changed, 55 insertions(+), 5 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index d0c7a93651..616c3ed2a2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors)
+    bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
-      allocate_xla_tensors_(allocate_xla_tensors) {}
+      allocate_xla_tensors_(allocate_xla_tensors),
+      use_multiple_streams_(use_multiple_streams) {
+  if (use_multiple_streams_) {
+    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+                                    "be allocating XLA tensors!";
+  }
+}
 
 void XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
     const std::map<int, OptionalTensor>& variables) {
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   // Build ShapedBuffers that point directly to the Tensor buffers.
   arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
   arg_buffers_.resize(kernel->xla_input_shapes.size());
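Note: the new constructor argument couples two flags. Multi-stream execution is only safe when outputs are allocated as XlaTensors, because only an XlaTensor can carry the definition events that later hunks use for cross-stream synchronization. A minimal stand-in for the enforced invariant (plain C++; LaunchContextSketch is illustrative, not the TensorFlow class):

    #include <cassert>
    #include <iostream>

    // Stand-in for XlaComputationLaunchContext; names are illustrative only.
    class LaunchContextSketch {
     public:
      LaunchContextSketch(bool allocate_xla_tensors, bool use_multiple_streams)
          : allocate_xla_tensors_(allocate_xla_tensors),
            use_multiple_streams_(use_multiple_streams) {
        // Mirrors the CHECK above: a plain Tensor has nowhere to record
        // "defined on stream S", so cross-stream consumers could not wait.
        if (use_multiple_streams_) {
          assert(allocate_xla_tensors_ &&
                 "multiple streams require XLA tensor allocation");
        }
      }

     private:
      bool allocate_xla_tensors_;
      bool use_multiple_streams_;
    };

    int main() {
      LaunchContextSketch ok(/*allocate_xla_tensors=*/true,
                             /*use_multiple_streams=*/true);
      std::cout << "invariant holds\n";
      return 0;
    }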
@@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
       t = &(ctx->input(arg_num));
     }
 
+    if (use_multiple_streams_) {
+      CHECK(stream) << "Must have a stream available when using XLA tensors!";
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      CHECK(xla_tensor);
+      if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+        stream->ThenWaitFor(event);
+        xla_tensor->SetDefinedOn(stream);
+      }
+    }
+
     const xla::Shape on_device_shape =
         client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
     if (xla::ShapeUtil::IsTuple(on_device_shape)) {
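Note: the GetDefinitionEvent/SetDefinedOn pair implements a one-shot wait per consumer stream: the first time a stream reads the tensor it waits on the producer's event and is then remembered, so later reads on that stream skip the wait. A self-contained sketch of that protocol (FakeStream, FakeEvent, and DefinitionTracker are stand-ins, not TensorFlow or StreamExecutor types):

    #include <cstdio>
    #include <set>

    struct FakeEvent { int id; };

    struct FakeStream {
      int id;
      void WaitFor(const FakeEvent& e) {
        std::printf("stream %d waits for event %d\n", id, e.id);
      }
    };

    class DefinitionTracker {
     public:
      // Producer side: remember the event and mark the producing stream.
      void SetDefinedOn(FakeStream* s, FakeEvent e) {
        event_ = e;
        streams_.insert(s->id);
      }
      // Consumer side: mark a stream that has finished waiting.
      void SetDefinedOn(FakeStream* s) { streams_.insert(s->id); }
      // Returns the event to wait on, or nullptr if `s` already synchronized.
      const FakeEvent* GetDefinitionEvent(FakeStream* s) const {
        return streams_.count(s->id) ? nullptr : &event_;
      }

     private:
      FakeEvent event_{0};
      std::set<int> streams_;
    };

    int main() {
      FakeStream producer{1}, consumer{2};
      DefinitionTracker tracker;
      tracker.SetDefinedOn(&producer, FakeEvent{42});  // record after the write
      if (const FakeEvent* e = tracker.GetDefinitionEvent(&consumer)) {
        consumer.WaitFor(*e);             // first read from stream 2 waits...
        tracker.SetDefinedOn(&consumer);  // ...then stream 2 is remembered
      }
      // A second read from the same stream skips the wait entirely.
      if (tracker.GetDefinitionEvent(&consumer) == nullptr) {
        std::printf("stream 2 already synchronized\n");
      }
      return 0;
    }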
@@ -176,6 +194,21 @@ void XlaComputationLaunchContext::PopulateOutputs(
   }
   CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
 
+  // If the on-host-shape isn't a tuple, create a new single-element tuple
+  // buffer with a nullptr root index table. This allows the code below to treat
+  // output as a tuple unconditionally.
+  if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) {
+    ShapedBuffer nontuple_buffer = output.release();
+    ShapedBuffer buffer(
+        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
+        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
+        output.platform(), output.device_ordinal());
+    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
+                                     /*source_base_index=*/{},
+                                     /*target_base_index=*/{0});
+    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
+  }
+
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
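Note: wrapping a non-tuple result in a one-element tuple lets the loop below address every output as element {output_num} of a tuple instead of branching on the result shape. The same normalization idea in self-contained form (std::variant stands in for ShapedBuffer's non-tuple/tuple split; nothing here is an XLA API):

    #include <iostream>
    #include <string>
    #include <variant>
    #include <vector>

    using Buffer = std::string;  // stands in for a device buffer
    // A result is either a single buffer or a tuple of buffers.
    using Result = std::variant<Buffer, std::vector<Buffer>>;

    // Normalize a result so callers can always iterate tuple elements,
    // mirroring the one-element-tuple wrap in the hunk above.
    std::vector<Buffer> AsTuple(Result r) {
      if (auto* single = std::get_if<Buffer>(&r)) {
        return {std::move(*single)};  // wrap: {result} instead of result
      }
      return std::get<std::vector<Buffer>>(std::move(r));
    }

    int main() {
      for (const Buffer& b : AsTuple(Result{Buffer{"scalar-result"}}))
        std::cout << b << "\n";
      for (const Buffer& b : AsTuple(Result{std::vector<Buffer>{"t0", "t1"}}))
        std::cout << b << "\n";
      return 0;
    }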
@@ -230,9 +263,20 @@ void XlaComputationLaunchContext::PopulateOutputs(
       Tensor* output_tensor;
       OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
       XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
-      CHECK(xla_tensor);
-      xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
-          ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+      if (xla_tensor) {
+        xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+            ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+        if (use_multiple_streams_) {
+          se::Event event(stream->parent());
+          CHECK(event.Init());
+          stream->ThenRecordEvent(&event);
+          xla_tensor->SetDefinedOn(stream, std::move(event));
+        }
+      } else {
+        // xla_tensor wasn't valid, which must mean this is a zero-element
+        // tensor.
+        CHECK_EQ(output_tensor->TotalBytes(), 0);
+      }
     } else {
       Tensor output_tensor = XlaTensorBuffer::MakeTensor(
           ctx->expected_output_dtype(i), shape, buffer, allocator);
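Note: the rewritten branch does two things: a valid XlaTensor takes ownership of its sub-buffer and, under multiple streams, gets a freshly recorded definition event; a null XlaTensor is tolerated only for empty tensors, which own no device memory. A small stand-in for that zero-element check (assert replaces CHECK_EQ; TensorSketch and PublishOutput are hypothetical):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in: a null buffer models XlaTensor::FromTensor
    // returning null for a tensor with no backing device memory.
    struct TensorSketch {
      const void* buffer;       // null only when the tensor is empty
      std::size_t total_bytes;  // 0 for zero-element tensors
    };

    void PublishOutput(const TensorSketch& t) {
      if (t.buffer != nullptr) {
        std::printf("hand sub-buffer to XlaTensor, record definition event\n");
      } else {
        // Matches CHECK_EQ(output_tensor->TotalBytes(), 0): a missing buffer
        // is only legitimate when there is nothing to define.
        assert(t.total_bytes == 0 && "missing buffer on a non-empty tensor");
      }
    }

    int main() {
      PublishOutput(TensorSketch{nullptr, 0});  // zero-element tensor: accepted
      return 0;
    }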
@@ -282,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
       CHECK(xla_tensor);
       xla_tensor->set_shaped_buffer(
           ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+      if (use_multiple_streams_) {
+        se::Event event(stream->parent());
+        CHECK(event.Init());
+        stream->ThenRecordEvent(&event);
+        xla_tensor->SetDefinedOn(stream, std::move(event));
+      }
       *variable->tensor() = output_tensor;
     } else {
       Tensor output_tensor = XlaTensorBuffer::MakeTensor(
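Note: this is the same record-and-mark sequence used for plain outputs above, now applied to updated resource variables. As a sketch, the shared steps reduce to one hypothetical helper (RecordDefinition and the Stream/Event/XlaTensorSketch types below are illustrative, not part of the patch):

    #include <cassert>
    #include <cstdio>

    struct Event { bool initialized = false; };

    struct Stream {
      Event* last_recorded = nullptr;
      void RecordEvent(Event* e) { last_recorded = e; }
    };

    struct XlaTensorSketch {
      Stream* defined_on = nullptr;
      Event event;
    };

    // Hypothetical helper capturing the duplicated block: init the event,
    // record it on the defining stream, then mark the tensor defined there.
    void RecordDefinition(XlaTensorSketch* t, Stream* s) {
      t->event.initialized = true;  // stands in for CHECK(event.Init())
      s->RecordEvent(&t->event);    // stands in for ThenRecordEvent(&event)
      t->defined_on = s;            // stands in for SetDefinedOn(stream, event)
    }

    int main() {
      Stream s;
      XlaTensorSketch t;
      RecordDefinition(&t, &s);
      assert(s.last_recorded == &t.event && t.defined_on == &s);
      std::printf("definition recorded\n");
      return 0;
    }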