diff options
author | Derek Murray <mrry@google.com> | 2018-09-23 18:28:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-23 18:32:45 -0700 |
commit | 167272ead245ac9e0183da807d996ba9d6e401b0 (patch) | |
tree | b90498bab003598622f9ac117fc1a4336c76ee76 /tensorflow/core/common_runtime | |
parent | 1f8db608007ae60f89bf38c4c6af98a0248f214e (diff) |
[tf.data] Add `tf.contrib.data.Optional` support to `Structure`.
This change switches `tf.contrib.data.Optional` to use a `Structure` class to represent
the structure of its value, instead of `output_types`, `output_shapes`, and `output_classes` properties. It adds support for nesting `Optional` objects and representing their structure.
This change also makes a modification to the `Structure` class: `Structure.is_compatible_with(x)` now takes another `Structure` as the `x` argument, instead of a value. This makes it easier to work with nested structures (where we might not have a value readily available), and better matches the interface of other `is_compatible_with()` methods (e.g. in `tf.TensorShape` and `tf.DType`).
Finally, in the process of making this change, I observed possible crash-failures when a DT_VARIANT tensor containing another DT_VARIANT tensor is copied between CPU and GPU. This change "fixes" the immediate problem by raising an UnimplementedError, but more work will be necessary to support the full range of use cases.
PiperOrigin-RevId: 214198993
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.cc | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index cf3d1f0b79..d800a86199 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -347,7 +347,12 @@ namespace { static Status WrappedTensorDeviceCopy( const Tensor& from, Tensor* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { - if (DMAHelper::CanUseDMA(&from)) { + if (from.dtype() == DT_VARIANT) { + // TODO(b/116349787): Implement support for nested variants. + return errors::Unimplemented( + "Support for copying nested variants to device has not yet been " + "implemented."); + } else if (DMAHelper::CanUseDMA(&from)) { TF_RETURN_IF_ERROR(copy(from, to)); } else { *to = from; |