aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-23 18:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 18:32:45 -0700
commit167272ead245ac9e0183da807d996ba9d6e401b0 (patch)
treeb90498bab003598622f9ac117fc1a4336c76ee76 /tensorflow/core/common_runtime
parent1f8db608007ae60f89bf38c4c6af98a0248f214e (diff)
[tf.data] Add `tf.contrib.data.Optional` support to `Structure`.
This change switches `tf.contrib.data.Optional` to use a `Structure` class to represent the structure of its value, instead of `output_types`, `output_shapes`, and `output_classes` properties. It adds support for nesting `Optional` objects and representing their structure. This change also makes a modification to the `Structure` class: `Structure.is_compatible_with(x)` now takes another `Structure` as the `x` argument, instead of a value. This makes it easier to work with nested structures (where we might not have a value readily available), and better matches the interface of other `is_compatible_with()` methods (e.g. in `tf.TensorShape` and `tf.DType`). Finally, in the process of making this change, I observed possible crash-failures when a DT_VARIANT tensor containing another DT_VARIANT tensor is copied between CPU and GPU. This change "fixes" the immediate problem by raising an UnimplementedError, but more work will be necessary to support the full range of use cases. PiperOrigin-RevId: 214198993
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index cf3d1f0b79..d800a86199 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -347,7 +347,12 @@ namespace {
static Status WrappedTensorDeviceCopy(
const Tensor& from, Tensor* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
- if (DMAHelper::CanUseDMA(&from)) {
+ if (from.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ } else if (DMAHelper::CanUseDMA(&from)) {
TF_RETURN_IF_ERROR(copy(from, to));
} else {
*to = from;