author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-02 03:36:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-02 03:41:05 -0700
commit    f22037abf5a6f4581f5fb6013f72f91747f22965 (patch)
tree      d28e0949d5c7bd436cc41f1af8602688167eebb0 /tensorflow/compiler/jit
parent    44da41e4900c3fd481f12c9aa4c49679c9f32fa4 (diff)
Add a hint parameter to TransferLiteralToDeviceAsync that the implementation can use to accelerate transfers.
PiperOrigin-RevId: 215362667
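The hint is chosen in the JIT device context: when CopyCPUTensorToDevice has just allocated the destination shaped buffer, its prior contents are undefined, so the transfer manager is free to ignore them. A minimal sketch of that decision, using only the enum values visible in the diff below (kBufferUndefined, kNoHint); the free function is illustrative, not part of the change:

    // Sketch: pick a transfer hint based on whether the destination buffer
    // was freshly allocated (its previous contents are undefined).
    xla::TransferManager::TransferToDeviceHint ChooseHint(bool buffer_is_fresh) {
      return buffer_is_fresh ? xla::TransferManager::kBufferUndefined
                             : xla::TransferManager::kNoHint;
    }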
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc  15
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h     3
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index af83c792e5..e083652978 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
}
}
-Status XlaTransferManager::TransferLiteralToDevice(
- const Tensor& host_tensor, Tensor* device_tensor) const {
+Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
+ Tensor* device_tensor,
+ bool buffer_is_fresh) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
@@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
+ xla::TransferManager::TransferToDeviceHint hint =
+ buffer_is_fresh ? xla::TransferManager::kBufferUndefined
+ : xla::TransferManager::kNoHint;
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_.get(), *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer, hint));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
+ bool buffer_is_fresh = false;
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
@@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
done(s);
return;
}
+ buffer_is_fresh = true;
}
Status status;
@@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+ status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
+ buffer_is_fresh);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index df82421294..a4c0c296fc 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -67,7 +67,8 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
- Tensor* device_tensor) const;
+ Tensor* device_tensor,
+ bool buffer_is_fresh) const;
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;
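For context, a backend could consume the hint roughly as follows. This is a hypothetical sketch, not the actual XLA transfer manager; the class name and the WaitForPriorWrites / EnqueueHostToDeviceCopy helpers are invented for illustration, and only the four-argument TransferLiteralToDeviceAsync shape and the hint values come from the diff above.

    // Hypothetical backend use of the hint: if the caller promises the
    // destination buffer's contents are undefined, the copy need not be
    // ordered after earlier device writes to that buffer.
    xla::Status MyTransferManager::TransferLiteralToDeviceAsync(
        se::Stream* stream, const xla::LiteralSlice& literal,
        const xla::ShapedBuffer& device_buffer, TransferToDeviceHint hint) {
      if (hint != kBufferUndefined) {
        // Previous contents may still matter; keep conservative ordering.
        WaitForPriorWrites(stream, device_buffer);  // hypothetical helper
      }
      // Enqueue the host-to-device copy on the given stream.
      return EnqueueHostToDeviceCopy(stream, literal, device_buffer);  // hypothetical helper
    }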