aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-23 18:28:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 18:32:45 -0700
commit167272ead245ac9e0183da807d996ba9d6e401b0 (patch)
treeb90498bab003598622f9ac117fc1a4336c76ee76 /tensorflow/python/data
parent1f8db608007ae60f89bf38c4c6af98a0248f214e (diff)
[tf.data] Add `tf.contrib.data.Optional` support to `Structure`.
This change switches `tf.contrib.data.Optional` to use a `Structure` class to represent the structure of its value, instead of `output_types`, `output_shapes`, and `output_classes` properties. It adds support for nesting `Optional` objects and representing their structure. This change also makes a modification to the `Structure` class: `Structure.is_compatible_with(x)` now takes another `Structure` as the `x` argument, instead of a value. This makes it easier to work with nested structures (where we might not have a value readily available), and better matches the interface of other `is_compatible_with()` methods (e.g. in `tf.TensorShape` and `tf.DType`). Finally, in the process of making this change, I observed possible crash-failures when a DT_VARIANT tensor containing another DT_VARIANT tensor is copied between CPU and GPU. This change "fixes" the immediate problem by raising an UnimplementedError, but more work will be necessary to support the full range of use cases. PiperOrigin-RevId: 214198993
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py176
-rw-r--r--tensorflow/python/data/ops/BUILD5
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py13
-rw-r--r--tensorflow/python/data/ops/optional_ops.py150
-rw-r--r--tensorflow/python/data/util/structure.py131
-rw-r--r--tensorflow/python/data/util/structure_test.py36
7 files changed, 316 insertions, 196 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index f97116cadd..28ee3ebaa6 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -394,6 +394,7 @@ cuda_py_test(
size = "small",
srcs = ["optional_ops_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index c344513e71..706a65fe55 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,14 +35,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase):
+class OptionalTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
@@ -50,15 +49,6 @@ class OptionalTest(test.TestCase):
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
})
- self.assertEqual({
- "a": dtypes.float32,
- "b": (dtypes.string, dtypes.string)
- }, opt.output_types)
- self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes)
- self.assertEqual({
- "a": ops.Tensor,
- "b": (ops.Tensor, ops.Tensor)
- }, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual({
"a": 37.0,
@@ -76,46 +66,29 @@ class OptionalTest(test.TestCase):
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=np.array([2, 2]))
opt = optional_ops.Optional.from_value((st_0, st_1))
- self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types)
- self.assertEqual(([1], [2, 2]), opt.output_shapes)
- self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
- opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ val_0, val_1 = opt.get_value()
+ for expected, actual in [(st_0, val_0), (st_1, val_1)]:
+ self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
+ self.assertAllEqual(expected.values, self.evaluate(actual.values))
+ self.assertAllEqual(expected.dense_shape,
+ self.evaluate(actual.dense_shape))
@test_util.run_in_graph_and_eager_modes
def testFromNone(self):
- opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(),
- dtypes.float32, ops.Tensor)
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
+ value_structure = structure.TensorStructure(dtypes.float32, [])
+ opt = optional_ops.Optional.none_from_structure(value_structure)
+ self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.float32, [1])))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.int32, [])))
self.assertFalse(self.evaluate(opt.has_value()))
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
- def testStructureMismatchError(self):
- tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar())
- tuple_output_types = (dtypes.float32, dtypes.float32)
- tuple_output_classes = (ops.Tensor, ops.Tensor)
-
- dict_output_shapes = {
- "a": tensor_shape.scalar(),
- "b": tensor_shape.scalar()
- }
- dict_output_types = {"a": dtypes.float32, "b": dtypes.float32}
- dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor}
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, tuple_output_types, dict_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, dict_output_types, tuple_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- dict_output_shapes, tuple_output_types, tuple_output_classes)
-
@test_util.run_in_graph_and_eager_modes
def testCopyToGPU(self):
if not test_util.is_gpu_available():
@@ -126,17 +99,15 @@ class OptionalTest(test.TestCase):
(constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure(
- tensor_shape.scalar(), dtypes.float32, ops.Tensor)
+ structure.TensorStructure(dtypes.float32, []))
with ops.device("/gpu:0"):
gpu_optional_with_value = optional_ops._OptionalImpl(
array_ops.identity(optional_with_value._variant_tensor),
- optional_with_value.output_shapes, optional_with_value.output_types,
- optional_with_value.output_classes)
+ optional_with_value.value_structure)
gpu_optional_none = optional_ops._OptionalImpl(
array_ops.identity(optional_none._variant_tensor),
- optional_none.output_shapes, optional_none.output_types,
- optional_none.output_classes)
+ optional_none.value_structure)
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
gpu_optional_with_value_values = gpu_optional_with_value.get_value()
@@ -148,14 +119,101 @@ class OptionalTest(test.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
- def testIteratorGetNextAsOptional(self):
- ds = dataset_ops.Dataset.range(3)
+ def _assertElementValueEqual(self, expected, actual):
+ if isinstance(expected, dict):
+ self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
+ for k in expected.keys():
+ self._assertElementValueEqual(expected[k], actual[k])
+ elif isinstance(expected, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(expected.indices, actual.indices)
+ self.assertAllEqual(expected.values, actual.values)
+ self.assertAllEqual(expected.dense_shape, actual.dense_shape)
+ else:
+ self.assertAllEqual(expected, actual)
+
+ # pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Tensor", lambda: constant_op.constant(37.0),
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
+ dense_shape=[1]),
+ structure.SparseTensorStructure(dtypes.int32, [1])),
+ ("Nest", lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.TensorStructure(dtypes.string, [1]),
+ structure.TensorStructure(dtypes.string, []))})),
+ ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+ optional_ops.OptionalStructure(
+ structure.TensorStructure(dtypes.float32, []))),
+ )
+ def testOptionalStructure(self, tf_value_fn, expected_value_structure):
+ tf_value = tf_value_fn()
+ opt = optional_ops.Optional.from_value(tf_value)
+
+ self.assertTrue(
+ expected_value_structure.is_compatible_with(opt.value_structure))
+ self.assertTrue(
+ opt.value_structure.is_compatible_with(expected_value_structure))
+
+ opt_structure = structure.Structure.from_value(opt)
+ self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
+ self.assertTrue(opt_structure.is_compatible_with(opt_structure))
+ self.assertTrue(opt_structure._value_structure.is_compatible_with(
+ expected_value_structure))
+ self.assertEqual([dtypes.variant], opt_structure._flat_types)
+ self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)
+
+ # All OptionalStructure objects are not compatible with a non-optional
+ # value.
+ non_optional_structure = structure.Structure.from_value(
+ constant_op.constant(42.0))
+ self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))
+
+ # Assert that the optional survives a round-trip via _from_tensor_list()
+ # and _to_tensor_list().
+ round_trip_opt = opt_structure._from_tensor_list(
+ opt_structure._to_tensor_list(opt))
+ if isinstance(tf_value, optional_ops.Optional):
+ self.assertEqual(
+ self.evaluate(tf_value.get_value()),
+ self.evaluate(round_trip_opt.get_value().get_value()))
+ else:
+ self.assertEqual(
+ self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
+
+ @parameterized.named_parameters(
+ ("Tensor", np.array([1, 2, 3], dtype=np.int32),
+ lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+ ("SparseTensor", sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+ False),
+ ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
+ "b": sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32),
+ dense_shape=[2, 2])},
+ lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+ "b": sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
+ dense_shape=[2, 2])}, False),
+ )
+ def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
+ if not works_on_gpu and test.is_gpu_available():
+ self.skipTest("Test case not yet supported on GPU.")
+ ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
iterator = ds.make_initializable_iterator()
next_elem = iterator_ops.get_next_as_optional(iterator)
- self.assertTrue(isinstance(next_elem, optional_ops.Optional))
- self.assertEqual(ds.output_types, next_elem.output_types)
- self.assertEqual(ds.output_shapes, next_elem.output_shapes)
- self.assertEqual(ds.output_classes, next_elem.output_classes)
+ self.assertIsInstance(next_elem, optional_ops.Optional)
+ self.assertTrue(
+ next_elem.value_structure.is_compatible_with(
+ structure.Structure.from_value(tf_value_fn())))
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
with self.cached_session() as sess:
@@ -169,10 +227,10 @@ class OptionalTest(test.TestCase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
sess.run(iterator.initializer)
- for i in range(3):
+ for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
+ self._assertElementValueEqual(np_value, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 9dffc38820..76bf2470b1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -64,6 +64,7 @@ py_library(
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
],
@@ -78,10 +79,8 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
],
)
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 8f8e026df9..cae00cdbfc 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -85,10 +86,10 @@ class Iterator(checkpointable.CheckpointableBase):
initializer: A `tf.Operation` that should be run to initialize this
iterator.
output_types: A nested structure of `tf.DType` objects corresponding to
- each component of an element of this dataset.
+ each component of an element of this iterator.
output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of an element of this dataset.
- output_classes: A nested structure of Python `type` object corresponding
+ corresponding to each component of an element of this iterator.
+ output_classes: A nested structure of Python `type` objects corresponding
to each component of an element of this iterator.
"""
self._iterator_resource = iterator_resource
@@ -670,6 +671,6 @@ def get_next_as_optional(iterator):
output_shapes=nest.flatten(
sparse.as_dense_shapes(iterator.output_shapes,
iterator.output_classes))),
- output_shapes=iterator.output_shapes,
- output_types=iterator.output_types,
- output_classes=iterator.output_classes)
+ structure.Structure._from_legacy_structure(iterator.output_types,
+ iterator.output_shapes,
+ iterator.output_classes))
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index b75b98dc72..3bbebd7878 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -19,11 +19,9 @@ from __future__ import print_function
import abc
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
@@ -67,36 +65,14 @@ class Optional(object):
raise NotImplementedError("Optional.get_value()")
@abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of this optional.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of this optional.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_shapes")
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of this optional.
+ def value_structure(self):
+ """The structure of the components of this optional.
Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of this optional.
+ A `Structure` object representing the structure of the components of this
+ optional.
"""
- raise NotImplementedError("Optional.output_types")
+ raise NotImplementedError("Optional.value_structure")
@staticmethod
def from_value(value):
@@ -108,48 +84,30 @@ class Optional(object):
Returns:
An `Optional` that wraps `value`.
"""
- # TODO(b/110122868): Consolidate this destructuring logic with the
- # similar code in `Dataset.from_tensors()`.
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
- value = nest.pack_sequence_as(value, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(value))
- ])
-
- encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
- output_classes = sparse.get_classes(value)
- output_shapes = nest.pack_sequence_as(
- value, [t.get_shape() for t in nest.flatten(value)])
- output_types = nest.pack_sequence_as(
- value, [t.dtype for t in nest.flatten(value)])
+ value_structure = structure.Structure.from_value(value)
+ encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
- output_shapes, output_types, output_classes)
+ value_structure)
@staticmethod
- def none_from_structure(output_shapes, output_types, output_classes):
+ def none_from_structure(value_structure):
"""Returns an `Optional` that has no value.
- NOTE: This method takes arguments that define the structure of the value
+ NOTE: This method takes an argument that defines the structure of the value
that would be contained in the returned `Optional` if it had a value.
Args:
- output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of this optional.
- output_types: A nested structure of `tf.DType` objects corresponding to
- each component of this optional.
- output_classes: A nested structure of Python `type` objects corresponding
- to each component of this optional.
+ value_structure: A `Structure` object representing the structure of the
+ components of this optional.
Returns:
An `Optional` that has no value.
"""
- return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes,
- output_types, output_classes)
+ return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure)
class _OptionalImpl(Optional):
@@ -159,20 +117,9 @@ class _OptionalImpl(Optional):
`Optional.__init__()` in the public API.
"""
- def __init__(self, variant_tensor, output_shapes, output_types,
- output_classes):
- # TODO(b/110122868): Consolidate the structure validation logic with the
- # similar logic in `Iterator.from_structure()` and
- # `Dataset.from_generator()`.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- nest.assert_same_structure(output_types, output_shapes)
- nest.assert_same_structure(output_types, output_classes)
+ def __init__(self, variant_tensor, value_structure):
self._variant_tensor = variant_tensor
- self._output_shapes = output_shapes
- self._output_types = output_types
- self._output_classes = output_classes
+ self._value_structure = value_structure
def has_value(self, name=None):
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
@@ -182,28 +129,55 @@ class _OptionalImpl(Optional):
# in `Iterator.get_next()` and `StructuredFunctionWrapper`.
with ops.name_scope(name, "OptionalGetValue",
[self._variant_tensor]) as scope:
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(
- self._output_types,
- gen_dataset_ops.optional_get_value(
- self._variant_tensor,
- name=scope,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types,
- self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self._output_shapes,
- self._output_classes)))),
- self._output_types, self._output_shapes, self._output_classes)
+ # pylint: disable=protected-access
+ return self._value_structure._from_tensor_list(
+ gen_dataset_ops.optional_get_value(
+ self._variant_tensor,
+ name=scope,
+ output_types=self._value_structure._flat_types,
+ output_shapes=self._value_structure._flat_shapes))
@property
- def output_classes(self):
- return self._output_classes
+ def value_structure(self):
+ return self._value_structure
+
+
+class OptionalStructure(structure.Structure):
+ """Represents an optional potentially containing a structured value."""
+
+ def __init__(self, value_structure):
+ self._value_structure = value_structure
@property
- def output_shapes(self):
- return self._output_shapes
+ def _flat_shapes(self):
+ return [tensor_shape.scalar()]
@property
- def output_types(self):
- return self._output_types
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, OptionalStructure) and
+ self._value_structure.is_compatible_with(other._value_structure))
+
+ def _to_tensor_list(self, value):
+ return [value._variant_tensor] # pylint: disable=protected-access
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "OptionalStructure corresponds to a single tf.variant scalar.")
+ # pylint: disable=protected-access
+ return _OptionalImpl(flat_value[0], self._value_structure)
+
+ @staticmethod
+ def from_value(value):
+ return OptionalStructure(value.value_structure)
+
+
+# pylint: disable=protected-access
+structure.Structure._register_custom_converter(Optional,
+ OptionalStructure.from_value)
+# pylint: enable=protected-access
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index c5764b8dfe..a90ca258c0 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import sparse_ops
+_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
+
+
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
@@ -64,12 +67,10 @@ class Structure(object):
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractmethod
- def is_compatible_with(self, value):
- """Returns `True` if `value` is compatible with this structure.
+ def is_compatible_with(self, other):
+ """Returns `True` if `other` is compatible with this structure.
- A value `value` is compatible with a structure `s` if
- `Structure.from_value(value)` would return a structure `t` that is a
- "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+ A structure `t` is a "subtype" of `s` if:
* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
@@ -83,10 +84,10 @@ class Structure(object):
`tf.TensorShape.is_compatible_with`.
Args:
- value: A potentially structured value.
+ other: A `Structure`.
Returns:
- `True` if `value` matches this structure, otherwise `False`.
+ `True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")
@@ -98,7 +99,7 @@ class Structure(object):
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.
- Requires: `self.is_compatible_with(value)`.
+ Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
@@ -137,9 +138,8 @@ class Structure(object):
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
-
- # TODO(b/110122868): Add support for custom types, Dataset, and Optional
- # to this method.
+ # TODO(b/110122868): Add support for custom types and Dataset to this
+ # method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
@@ -147,12 +147,76 @@ class Structure(object):
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
+ for converter_type, converter_fn in (
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
+ if isinstance(value, converter_type):
+ return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)
+ @staticmethod
+ def _from_legacy_structure(output_types, output_shapes, output_classes):
+ """Returns a `Structure` that represents the given legacy structure.
+
+ This method provides a way to convert from the existing `Dataset` and
+ `Iterator` structure-related properties to a `Structure` object.
+
+ TODO(b/110122868): Remove this method once `Structure` is used throughout
+ `tf.data`.
+
+ Args:
+ output_types: A nested structure of `tf.DType` objects corresponding to
+ each component of a structured value.
+ output_shapes: A nested structure of `tf.TensorShape` objects
+ corresponding to each component a structured value.
+ output_classes: A nested structure of Python `type` objects corresponding
+ to each component of a structured value.
+
+ Returns:
+ A `Structure`.
+
+ Raises:
+ TypeError: If a structure cannot be built the arguments, because one of
+ the component classes in `output_classes` is not supported.
+ """
+ flat_types = nest.flatten(output_types)
+ flat_shapes = nest.flatten(output_shapes)
+ flat_classes = nest.flatten(output_classes)
+ flat_ret = []
+ for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
+ flat_classes):
+ if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
+ flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
+ elif issubclass(flat_class, ops.Tensor):
+ flat_ret.append(TensorStructure(flat_type, flat_shape))
+ else:
+ # NOTE(mrry): Since legacy structures produced by iterators only
+ # comprise Tensors, SparseTensors, and nests, we do not need to support
+ # all structure types here.
+ raise TypeError(
+ "Could not build a structure for output class %r" % flat_type)
+
+ ret = nest.pack_sequence_as(output_classes, flat_ret)
+ if isinstance(ret, Structure):
+ return ret
+ else:
+ return NestedStructure(ret)
+
+ @staticmethod
+ def _register_custom_converter(type_object, converter_fn):
+ """Registers `converter_fn` for converting values of the given type.
+
+ Args:
+ type_object: A Python `type` object representing the type of values
+ accepted by `converter_fn`.
+ converter_fn: A function that takes one argument (an instance of the
+ type represented by `type_object`) and returns a `Structure`.
+ """
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn
+
# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
@@ -179,16 +243,21 @@ class NestedStructure(Structure):
def _flat_types(self):
return self._flat_types_list
- def is_compatible_with(self, value):
+ def is_compatible_with(self, other):
+ if not isinstance(other, NestedStructure):
+ return False
try:
- nest.assert_shallow_structure(self._nested_structure, value)
+ # pylint: disable=protected-access
+ nest.assert_same_structure(self._nested_structure,
+ other._nested_structure)
except (ValueError, TypeError):
return False
return all(
- s.is_compatible_with(v) for s, v in zip(
+ substructure.is_compatible_with(other_substructure)
+ for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
- nest.flatten_up_to(self._nested_structure, value)))
+ nest.flatten(other._nested_structure)))
def _to_tensor_list(self, value):
ret = []
@@ -201,7 +270,7 @@ class NestedStructure(Structure):
for sub_value, structure in zip(flat_value,
nest.flatten(self._nested_structure)):
- if not structure.is_compatible_with(sub_value):
+ if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
@@ -242,17 +311,13 @@ class TensorStructure(Structure):
def _flat_types(self):
return [self._dtype]
- def is_compatible_with(self, value):
- try:
- value = ops.convert_to_tensor(value, dtype=self._dtype)
- except (ValueError, TypeError):
- return False
-
- return (self._dtype.is_compatible_with(value.dtype) and
- self._shape.is_compatible_with(value.shape))
+ def is_compatible_with(self, other):
+ return (isinstance(other, TensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._shape.is_compatible_with(other._shape))
def _to_tensor_list(self, value):
- if not self.is_compatible_with(value):
+ if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]
@@ -260,7 +325,7 @@ class TensorStructure(Structure):
def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
- if not self.is_compatible_with(flat_value[0]):
+ if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return flat_value[0]
@@ -285,16 +350,10 @@ class SparseTensorStructure(Structure):
def _flat_types(self):
return [dtypes.variant]
- def is_compatible_with(self, value):
- try:
- value = sparse_tensor_lib.SparseTensor.from_value(value)
- except TypeError:
- return False
- return (isinstance(value, (sparse_tensor_lib.SparseTensor,
- sparse_tensor_lib.SparseTensorValue)) and
- self._dtype.is_compatible_with(value.dtype) and
- self._dense_shape.is_compatible_with(
- tensor_util.constant_value_as_shape(value.dense_shape)))
+ def is_compatible_with(self, other):
+ return (isinstance(other, SparseTensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._dense_shape.is_compatible_with(other._dense_shape))
def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index d0c7df67ae..2982763181 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -25,7 +25,9 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -106,13 +108,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWith(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(self, original_value, compatible_values,
+ incompatible_values):
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
- self.assertTrue(s.is_compatible_with(compatible_value))
+ self.assertTrue(
+ s.is_compatible_with(
+ structure.Structure.from_value(compatible_value)))
for incompatible_value in incompatible_values:
- self.assertFalse(s.is_compatible_with(incompatible_value))
+ self.assertFalse(
+ s.is_compatible_with(
+ structure.Structure.from_value(incompatible_value)))
# NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
# will be executed before the (eager- or graph-mode) test environment has been
@@ -322,6 +328,28 @@ class StructureTest(test.TestCase, parameterized.TestCase):
ValueError, "Expected 3 flat values in NestedStructure but got 2."):
s_2._from_tensor_list(flat_s_1)
+ @parameterized.named_parameters(
+ ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
+ sparse_tensor.SparseTensor,
+ structure.SparseTensorStructure(dtypes.int32, [2, 2])),
+ ("Nest",
+ {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
+ {"a": tensor_shape.scalar(),
+ "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())},
+ {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
+ structure.TensorStructure(dtypes.string, []))})),
+ )
+ def testFromLegacyStructure(self, output_types, output_shapes, output_classes,
+ expected_structure):
+ actual_structure = structure.Structure._from_legacy_structure(
+ output_types, output_shapes, output_classes)
+ self.assertTrue(expected_structure.is_compatible_with(actual_structure))
+ self.assertTrue(actual_structure.is_compatible_with(expected_structure))
if __name__ == "__main__":
test.main()