diff options
author | 2018-10-02 03:36:14 -0700 | |
---|---|---|
committer | 2018-10-02 03:41:05 -0700 | |
commit | f22037abf5a6f4581f5fb6013f72f91747f22965 (patch) | |
tree | d28e0949d5c7bd436cc41f1af8602688167eebb0 /tensorflow/compiler/jit | |
parent | 44da41e4900c3fd481f12c9aa4c49679c9f32fa4 (diff) |
Add a hint parameter to TransferLiteralToDeviceAsync that the implementation can use to accelerate transfers.
PiperOrigin-RevId: 215362667
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.h | 3 |
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index af83c792e5..e083652978 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager( } } -Status XlaTransferManager::TransferLiteralToDevice( - const Tensor& host_tensor, Tensor* device_tensor) const { +Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor, + Tensor* device_tensor, + bool buffer_is_fresh) const { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); @@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice( // synchronized. host_to_device_stream_->ThenWaitFor(stream_.get()); } + xla::TransferManager::TransferToDeviceHint hint = + buffer_is_fresh ? xla::TransferManager::kBufferUndefined + : xla::TransferManager::kNoHint; TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_.get(), *literal, shaped_buffer)); + host_to_device_stream_.get(), *literal, shaped_buffer, hint)); if (UseMultipleStreams()) { auto event = std::make_shared<se::Event>(stream_->parent()); TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; @@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } TensorShape shape = shape_or_status.ValueOrDie(); + bool buffer_is_fresh = false; if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, @@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, done(s); return; } + buffer_is_fresh = true; } Status status; @@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, "Tensor::CopyFrom failed when copying from CPU to XLA device")); return; } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor, + buffer_is_fresh); } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index df82421294..a4c0c296fc 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -67,7 +67,8 @@ class XlaTransferManager { private: Status TransferLiteralToDevice(const Tensor& host_tensor, - Tensor* device_tensor) const; + Tensor* device_tensor, + bool buffer_is_fresh) const; void TransferLiteralFromDevice(Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const; |