diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 60 |
1 files changed, 55 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d0c7a93651..616c3ed2a2 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors) + bool allocate_xla_tensors, bool use_multiple_streams) : client_(client), xla_allocator_(xla_allocator), - allocate_xla_tensors_(allocate_xla_tensors) {} + allocate_xla_tensors_(allocate_xla_tensors), + use_multiple_streams_(use_multiple_streams) { + if (use_multiple_streams_) { + CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must " + "be allocating XLA tensors!"; + } +} void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, const std::map<int, OptionalTensor>& variables) { + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); arg_buffers_.resize(kernel->xla_input_shapes.size()); @@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs( t = &(ctx->input(arg_num)); } + if (use_multiple_streams_) { + CHECK(stream) << "Must have a stream available when using XLA tensors!"; + XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor); + if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) { + stream->ThenWaitFor(event); + xla_tensor->SetDefinedOn(stream); + } + } + const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); if (xla::ShapeUtil::IsTuple(on_device_shape)) { @@ -176,6 +194,21 @@ void XlaComputationLaunchContext::PopulateOutputs( } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + // If the on-host-shape isn't a tuple, create a new single-element tuple + // buffer with a nullptr root index table. This allows the code below to treat + // output as a tuple unconditionally. + if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + ShapedBuffer nontuple_buffer = output.release(); + ShapedBuffer buffer( + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), + output.platform(), output.device_ordinal()); + buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), + /*source_base_index=*/{}, + /*target_base_index=*/{0}); + output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -230,9 +263,20 @@ void XlaComputationLaunchContext::PopulateOutputs( Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + se::Event event(stream->parent()); + CHECK(event.Init()); + stream->ThenRecordEvent(&event); + xla_tensor->SetDefinedOn(stream, std::move(event)); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); + } } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); @@ -282,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs( CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); + if (use_multiple_streams_) { + se::Event event(stream->parent()); + CHECK(event.Init()); + stream->ThenRecordEvent(&event); + xla_tensor->SetDefinedOn(stream, std::move(event)); + } *variable->tensor() = output_tensor; } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( |