path: root/tensorflow/compiler/jit/xla_launch_util.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-06 03:23:54 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-06 03:26:08 -0700
commit    58df8c97a7dc2ed2159e8137312fa29c0d7bcf67 (patch)
tree      1d586d1bb7bbaad81f227b746febd801dc289527 /tensorflow/compiler/jit/xla_launch_util.cc
parent    4eefd3a5e4a7f5432be7fd3981071dc6b727349f (diff)
internal change
PiperOrigin-RevId: 191869400
Diffstat (limited to 'tensorflow/compiler/jit/xla_launch_util.cc')
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc | 34
1 file changed, 25 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 354be1e1b5..50b0061d69 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -16,12 +16,14 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
@@ -165,6 +167,8 @@ void XlaComputationLaunchContext::PopulateOutputs(
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
+    VLOG(2) << "Result tuple shape (on device): "
+            << output->on_device_shape().DebugString();
   }
 
   CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
@@ -179,6 +183,10 @@
       const size_t total_bytes = const_tensor.TotalBytes();
       if (stream && total_bytes > 0) {
         // Copy host -> device. (Empty tensors don't have backing buffers.)
+        // Manually allocate memory using an XlaTensorBuffer so we can allocate
+        // as much memory as the device requires (as given by
+        // GetByteSizeRequirement). This avoids XlaTransferManager having to
+        // reallocate the device buffer later.
         VLOG(1) << "Constant output tensor on device";
         OP_REQUIRES_OK(
@@ -189,15 +197,23 @@
                                  client_, stream->parent()->device_ordinal()));
         }
-        const void* src_ptr = DMAHelper::base(&const_tensor);
-        gpu::DeviceMemoryBase dst_ptr =
-            XlaTensor::DeviceMemoryFromTensor(*output_tensor);
-        // Memcpying asynchronously is safe for the GPU, but the CPU uses a
-        // shared allocator so hold a reference to the copied-to buffer until
-        // complete.
-        TensorReference ref(*output_tensor);
-        stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes);
-        stream->ThenDoHostCallback([ref] { ref.Unref(); });
+        Device* device = dynamic_cast<Device*>(ctx->device());
+        OP_REQUIRES(ctx, device != nullptr,
+                    errors::Internal("DeviceBase was not a Device."));
+        ctx->op_device_context()->CopyCPUTensorToDevice(
+            &const_tensor, device, output_tensor,
+            [&](Status status) { TF_CHECK_OK(status); });
+
+        if (device->device_type() == DEVICE_GPU) {
+          // The GPUDeviceContext enqueues the host->device transfer in a
+          // separate stream from the main compute stream. We must ensure the
+          // compute stream is synchronized with the host->device transfer
+          // stream now otherwise we will create a race condition.
+          auto* gpu_device_context =
+              static_cast<GPUDeviceContext*>(ctx->op_device_context());
+          gpu_device_context->stream()->ThenWaitFor(
+              gpu_device_context->host_to_device_stream());
+        }
       } else {
         // No copy required.
         ctx->set_output(i, const_tensor);
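
For context on the last hunk: the host->device copy now runs on a transfer stream separate from the main compute stream, so the compute stream must explicitly wait for the transfer before any kernel reads the buffer. Below is a minimal editor's sketch of the same ordering constraint written directly against the CUDA runtime API; the function and variable names are illustrative, and only the two-stream split and the wait (mirroring ThenWaitFor above) are taken from the diff.

// Editor's sketch (not part of the commit): the ordering constraint from the
// hunk above, expressed with the CUDA runtime API. Names are illustrative.
#include <cstddef>
#include <cuda_runtime.h>

__global__ void Consume(const float* buf) {
  // Stand-in for the XLA computation that reads the copied buffer.
  (void)buf[threadIdx.x];
}

void CopyThenCompute(float* device_buf, const float* host_buf, size_t bytes) {
  cudaStream_t copy_stream, compute_stream;
  cudaStreamCreate(&copy_stream);
  cudaStreamCreate(&compute_stream);

  // Host->device transfer on a dedicated copy stream, as the
  // GPUDeviceContext does on its host_to_device_stream().
  cudaMemcpyAsync(device_buf, host_buf, bytes, cudaMemcpyHostToDevice,
                  copy_stream);

  // Make the compute stream wait for the transfer; this is the CUDA-level
  // analogue of stream()->ThenWaitFor(host_to_device_stream()).
  cudaEvent_t copy_done;
  cudaEventCreate(&copy_done);
  cudaEventRecord(copy_done, copy_stream);
  cudaStreamWaitEvent(compute_stream, copy_done, 0);

  // Without the wait above, this kernel could race with the memcpy.
  Consume<<<1, 256, 0, compute_stream>>>(device_buf);

  cudaEventDestroy(copy_done);
  cudaStreamDestroy(copy_stream);
  cudaStreamDestroy(compute_stream);
}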