diff options
author | 2018-04-06 03:23:54 -0700 | |
---|---|---|
committer | 2018-04-06 03:26:08 -0700 | |
commit | 58df8c97a7dc2ed2159e8137312fa29c0d7bcf67 (patch) | |
tree | 1d586d1bb7bbaad81f227b746febd801dc289527 /tensorflow/compiler/jit/xla_launch_util.cc | |
parent | 4eefd3a5e4a7f5432be7fd3981071dc6b727349f (diff) |
internal change
PiperOrigin-RevId: 191869400
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 34 |
1 files changed, 25 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 354be1e1b5..50b0061d69 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -16,12 +16,14 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -165,6 +167,8 @@ void XlaComputationLaunchContext::PopulateOutputs( // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); + VLOG(2) << "Result tuple shape (on device): " + << output->on_device_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); @@ -179,6 +183,10 @@ void XlaComputationLaunchContext::PopulateOutputs( const size_t total_bytes = const_tensor.TotalBytes(); if (stream && total_bytes > 0) { // Copy host -> device. (Empty tensors don't have backing buffers.) + // Manually allocate memory using an XlaTensorBuffer so we can allocate + // as much memory as the device requires (as given by + // GetByteSizeRequirement). This avoids XlaTransferManager having to + // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; OP_REQUIRES_OK( @@ -189,15 +197,23 @@ void XlaComputationLaunchContext::PopulateOutputs( client_, stream->parent()->device_ordinal())); } - const void* src_ptr = DMAHelper::base(&const_tensor); - gpu::DeviceMemoryBase dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*output_tensor); - // Memcpying asynchronously is safe for the GPU, but the CPU uses a - // shared allocator so hold a reference to the copied-to buffer until - // complete. - TensorReference ref(*output_tensor); - stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes); - stream->ThenDoHostCallback([ref] { ref.Unref(); }); + Device* device = dynamic_cast<Device*>(ctx->device()); + OP_REQUIRES(ctx, device != nullptr, + errors::Internal("DeviceBase was not a Device.")); + ctx->op_device_context()->CopyCPUTensorToDevice( + &const_tensor, device, output_tensor, + [&](Status status) { TF_CHECK_OK(status); }); + + if (device->device_type() == DEVICE_GPU) { + // The GPUDeviceContext enqueues the host->device transfer in a + // separate stream from the main compute stream. We must ensure the + // compute stream is synchronized with the host->device transfer + // stream now otherwise we will create a race condition. + auto* gpu_device_context = + static_cast<GPUDeviceContext*>(ctx->op_device_context()); + gpu_device_context->stream()->ThenWaitFor( + gpu_device_context->host_to_device_stream()); + } } else { // No copy required. ctx->set_output(i, const_tensor); |