diff options
author | Derek Murray <mrry@google.com> | 2018-08-05 21:52:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-05 21:56:36 -0700 |
commit | 02ae1e2e781b8e049d1fc1ab7b52f6ee7edb4423 (patch) | |
tree | 134171731db099fd5efcf600866e7ad8b8fff22d /tensorflow/core/common_runtime/copy_tensor.cc | |
parent | c42013f103ab8f6588cff1a8e59bc1ef81435bcc (diff) |
[tf.data] Add support for copying `Optional` variants to/from GPU.
PiperOrigin-RevId: 207490563
Diffstat (limited to 'tensorflow/core/common_runtime/copy_tensor.cc')
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 630b3702c8..f8cb854b52 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -340,4 +340,30 @@ Status CopyTensor::Register(DeviceType sender_device_type, return Status::OK(); } +namespace { + +// The following registrations enable a DT_VARIANT tensor element that contains +// a wrapped `tensorflow::Tensor` to be copied between devices. +static Status WrappedTensorDeviceCopy( + const Tensor& from, Tensor* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + if (DMAHelper::CanUseDMA(&from)) { + TF_RETURN_IF_ERROR(copy(from, to)); + } else { + *to = from; + } + + return Status::OK(); +} + +#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy) + +REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); +REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); +REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); + +} // namespace + } // namespace tensorflow |