aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/copy_tensor.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-08-05 21:52:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-05 21:56:36 -0700
commit02ae1e2e781b8e049d1fc1ab7b52f6ee7edb4423 (patch)
tree134171731db099fd5efcf600866e7ad8b8fff22d /tensorflow/core/common_runtime/copy_tensor.cc
parentc42013f103ab8f6588cff1a8e59bc1ef81435bcc (diff)
[tf.data] Add support for copying `Optional` variants to/from GPU.
PiperOrigin-RevId: 207490563
Diffstat (limited to 'tensorflow/core/common_runtime/copy_tensor.cc')
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index 630b3702c8..f8cb854b52 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -340,4 +340,30 @@ Status CopyTensor::Register(DeviceType sender_device_type,
return Status::OK();
}
+namespace {
+
+// The following registrations enable a DT_VARIANT tensor element that contains
+// a wrapped `tensorflow::Tensor` to be copied between devices.
+static Status WrappedTensorDeviceCopy(
+ const Tensor& from, Tensor* to,
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+ if (DMAHelper::CanUseDMA(&from)) {
+ TF_RETURN_IF_ERROR(copy(from, to));
+ } else {
+ *to = from;
+ }
+
+ return Status::OK();
+}
+
+#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+} // namespace
+
} // namespace tensorflow