diff options
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/optional_ops_test.py | 176 | ||||
-rw-r--r-- | tensorflow/python/data/ops/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/python/data/ops/iterator_ops.py | 13 | ||||
-rw-r--r-- | tensorflow/python/data/ops/optional_ops.py | 150 | ||||
-rw-r--r-- | tensorflow/python/data/util/structure.py | 131 | ||||
-rw-r--r-- | tensorflow/python/data/util/structure_test.py | 36 |
7 files changed, 316 insertions, 196 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index f97116cadd..28ee3ebaa6 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -394,6 +394,7 @@ cuda_py_test( size = "small", srcs = ["optional_ops_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:optional_ops", diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py index c344513e71..706a65fe55 100644 --- a/tensorflow/python/data/kernel_tests/optional_ops_test.py +++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py @@ -17,11 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import optional_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,14 +35,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class OptionalTest(test.TestCase): +class OptionalTest(test.TestCase, parameterized.TestCase): @test_util.run_in_graph_and_eager_modes def testFromValue(self): opt = optional_ops.Optional.from_value(constant_op.constant(37.0)) - self.assertEqual(dtypes.float32, opt.output_types) - self.assertEqual([], opt.output_shapes) - self.assertEqual(ops.Tensor, opt.output_classes) self.assertTrue(self.evaluate(opt.has_value())) self.assertEqual(37.0, self.evaluate(opt.get_value())) @@ -50,15 +49,6 @@ class OptionalTest(test.TestCase): "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) }) - self.assertEqual({ - "a": dtypes.float32, - "b": (dtypes.string, dtypes.string) - }, opt.output_types) - self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes) - self.assertEqual({ - "a": ops.Tensor, - "b": (ops.Tensor, ops.Tensor) - }, opt.output_classes) self.assertTrue(self.evaluate(opt.has_value())) self.assertEqual({ "a": 37.0, @@ -76,46 +66,29 @@ class OptionalTest(test.TestCase): values=np.array([-1., 1.], dtype=np.float32), dense_shape=np.array([2, 2])) opt = optional_ops.Optional.from_value((st_0, st_1)) - self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types) - self.assertEqual(([1], [2, 2]), opt.output_shapes) - self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), - opt.output_classes) + self.assertTrue(self.evaluate(opt.has_value())) + val_0, val_1 = opt.get_value() + for expected, actual in [(st_0, val_0), (st_1, val_1)]: + self.assertAllEqual(expected.indices, self.evaluate(actual.indices)) + self.assertAllEqual(expected.values, self.evaluate(actual.values)) + self.assertAllEqual(expected.dense_shape, + self.evaluate(actual.dense_shape)) @test_util.run_in_graph_and_eager_modes def testFromNone(self): - opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(), - dtypes.float32, ops.Tensor) - self.assertEqual(dtypes.float32, opt.output_types) - self.assertEqual([], opt.output_shapes) - self.assertEqual(ops.Tensor, opt.output_classes) + value_structure = structure.TensorStructure(dtypes.float32, []) + opt = optional_ops.Optional.none_from_structure(value_structure) + self.assertTrue(opt.value_structure.is_compatible_with(value_structure)) + self.assertFalse( + opt.value_structure.is_compatible_with( + structure.TensorStructure(dtypes.float32, [1]))) + self.assertFalse( + opt.value_structure.is_compatible_with( + structure.TensorStructure(dtypes.int32, []))) self.assertFalse(self.evaluate(opt.has_value())) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(opt.get_value()) - def testStructureMismatchError(self): - tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar()) - tuple_output_types = (dtypes.float32, dtypes.float32) - tuple_output_classes = (ops.Tensor, ops.Tensor) - - dict_output_shapes = { - "a": tensor_shape.scalar(), - "b": tensor_shape.scalar() - } - dict_output_types = {"a": dtypes.float32, "b": dtypes.float32} - dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor} - - with self.assertRaises(TypeError): - optional_ops.Optional.none_from_structure( - tuple_output_shapes, tuple_output_types, dict_output_classes) - - with self.assertRaises(TypeError): - optional_ops.Optional.none_from_structure( - tuple_output_shapes, dict_output_types, tuple_output_classes) - - with self.assertRaises(TypeError): - optional_ops.Optional.none_from_structure( - dict_output_shapes, tuple_output_types, tuple_output_classes) - @test_util.run_in_graph_and_eager_modes def testCopyToGPU(self): if not test_util.is_gpu_available(): @@ -126,17 +99,15 @@ class OptionalTest(test.TestCase): (constant_op.constant(37.0), constant_op.constant("Foo"), constant_op.constant(42))) optional_none = optional_ops.Optional.none_from_structure( - tensor_shape.scalar(), dtypes.float32, ops.Tensor) + structure.TensorStructure(dtypes.float32, [])) with ops.device("/gpu:0"): gpu_optional_with_value = optional_ops._OptionalImpl( array_ops.identity(optional_with_value._variant_tensor), - optional_with_value.output_shapes, optional_with_value.output_types, - optional_with_value.output_classes) + optional_with_value.value_structure) gpu_optional_none = optional_ops._OptionalImpl( array_ops.identity(optional_none._variant_tensor), - optional_none.output_shapes, optional_none.output_types, - optional_none.output_classes) + optional_none.value_structure) gpu_optional_with_value_has_value = gpu_optional_with_value.has_value() gpu_optional_with_value_values = gpu_optional_with_value.get_value() @@ -148,14 +119,101 @@ class OptionalTest(test.TestCase): self.evaluate(gpu_optional_with_value_values)) self.assertFalse(self.evaluate(gpu_optional_none_has_value)) - def testIteratorGetNextAsOptional(self): - ds = dataset_ops.Dataset.range(3) + def _assertElementValueEqual(self, expected, actual): + if isinstance(expected, dict): + self.assertItemsEqual(list(expected.keys()), list(actual.keys())) + for k in expected.keys(): + self._assertElementValueEqual(expected[k], actual[k]) + elif isinstance(expected, sparse_tensor.SparseTensorValue): + self.assertAllEqual(expected.indices, actual.indices) + self.assertAllEqual(expected.values, actual.values) + self.assertAllEqual(expected.dense_shape, actual.dense_shape) + else: + self.assertAllEqual(expected, actual) + + # pylint: disable=g-long-lambda + @parameterized.named_parameters( + ("Tensor", lambda: constant_op.constant(37.0), + structure.TensorStructure(dtypes.float32, [])), + ("SparseTensor", lambda: sparse_tensor.SparseTensor( + indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32), + dense_shape=[1]), + structure.SparseTensorStructure(dtypes.int32, [1])), + ("Nest", lambda: { + "a": constant_op.constant(37.0), + "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))}, + structure.NestedStructure({ + "a": structure.TensorStructure(dtypes.float32, []), + "b": (structure.TensorStructure(dtypes.string, [1]), + structure.TensorStructure(dtypes.string, []))})), + ("Optional", lambda: optional_ops.Optional.from_value(37.0), + optional_ops.OptionalStructure( + structure.TensorStructure(dtypes.float32, []))), + ) + def testOptionalStructure(self, tf_value_fn, expected_value_structure): + tf_value = tf_value_fn() + opt = optional_ops.Optional.from_value(tf_value) + + self.assertTrue( + expected_value_structure.is_compatible_with(opt.value_structure)) + self.assertTrue( + opt.value_structure.is_compatible_with(expected_value_structure)) + + opt_structure = structure.Structure.from_value(opt) + self.assertIsInstance(opt_structure, optional_ops.OptionalStructure) + self.assertTrue(opt_structure.is_compatible_with(opt_structure)) + self.assertTrue(opt_structure._value_structure.is_compatible_with( + expected_value_structure)) + self.assertEqual([dtypes.variant], opt_structure._flat_types) + self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes) + + # All OptionalStructure objects are not compatible with a non-optional + # value. + non_optional_structure = structure.Structure.from_value( + constant_op.constant(42.0)) + self.assertFalse(opt_structure.is_compatible_with(non_optional_structure)) + + # Assert that the optional survives a round-trip via _from_tensor_list() + # and _to_tensor_list(). + round_trip_opt = opt_structure._from_tensor_list( + opt_structure._to_tensor_list(opt)) + if isinstance(tf_value, optional_ops.Optional): + self.assertEqual( + self.evaluate(tf_value.get_value()), + self.evaluate(round_trip_opt.get_value().get_value())) + else: + self.assertEqual( + self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value())) + + @parameterized.named_parameters( + ("Tensor", np.array([1, 2, 3], dtype=np.int32), + lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True), + ("SparseTensor", sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], + values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]), + lambda: sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]), + False), + ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32), + "b": sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], + values=np.array([-1., 1.], dtype=np.float32), + dense_shape=[2, 2])}, + lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32), + "b": sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 0]], values=[37.0, 42.0], + dense_shape=[2, 2])}, False), + ) + def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu): + if not works_on_gpu and test.is_gpu_available(): + self.skipTest("Test case not yet supported on GPU.") + ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3) iterator = ds.make_initializable_iterator() next_elem = iterator_ops.get_next_as_optional(iterator) - self.assertTrue(isinstance(next_elem, optional_ops.Optional)) - self.assertEqual(ds.output_types, next_elem.output_types) - self.assertEqual(ds.output_shapes, next_elem.output_shapes) - self.assertEqual(ds.output_classes, next_elem.output_classes) + self.assertIsInstance(next_elem, optional_ops.Optional) + self.assertTrue( + next_elem.value_structure.is_compatible_with( + structure.Structure.from_value(tf_value_fn()))) elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() with self.cached_session() as sess: @@ -169,10 +227,10 @@ class OptionalTest(test.TestCase): # For each element of the dataset, assert that the optional evaluates to # the expected value. sess.run(iterator.initializer) - for i in range(3): + for _ in range(3): elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) self.assertTrue(elem_has_value) - self.assertEqual(i, elem_value) + self._assertElementValueEqual(np_value, elem_value) # After exhausting the iterator, `next_elem.has_value()` will evaluate to # false, and attempting to get the value will fail. diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 9dffc38820..76bf2470b1 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -64,6 +64,7 @@ py_library( "//tensorflow/python/compat", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", ], @@ -78,10 +79,8 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 8f8e026df9..cae00cdbfc 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.compat import compat from tensorflow.python.data.ops import optional_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.data.util import structure from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -85,10 +86,10 @@ class Iterator(checkpointable.CheckpointableBase): initializer: A `tf.Operation` that should be run to initialize this iterator. output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element of this dataset. + each component of an element of this iterator. output_shapes: A nested structure of `tf.TensorShape` objects - corresponding to each component of an element of this dataset. - output_classes: A nested structure of Python `type` object corresponding + corresponding to each component of an element of this iterator. + output_classes: A nested structure of Python `type` objects corresponding to each component of an element of this iterator. """ self._iterator_resource = iterator_resource @@ -670,6 +671,6 @@ def get_next_as_optional(iterator): output_shapes=nest.flatten( sparse.as_dense_shapes(iterator.output_shapes, iterator.output_classes))), - output_shapes=iterator.output_shapes, - output_types=iterator.output_types, - output_classes=iterator.output_classes) + structure.Structure._from_legacy_structure(iterator.output_types, + iterator.output_shapes, + iterator.output_classes)) diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py index b75b98dc72..3bbebd7878 100644 --- a/tensorflow/python/data/ops/optional_ops.py +++ b/tensorflow/python/data/ops/optional_ops.py @@ -19,11 +19,9 @@ from __future__ import print_function import abc -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops @@ -67,36 +65,14 @@ class Optional(object): raise NotImplementedError("Optional.get_value()") @abc.abstractproperty - def output_classes(self): - """Returns the class of each component of this optional. - - The expected values are `tf.Tensor` and `tf.SparseTensor`. - - Returns: - A nested structure of Python `type` objects corresponding to each - component of this optional. - """ - raise NotImplementedError("Optional.output_classes") - - @abc.abstractproperty - def output_shapes(self): - """Returns the shape of each component of this optional. - - Returns: - A nested structure of `tf.TensorShape` objects corresponding to each - component of this optional. - """ - raise NotImplementedError("Optional.output_shapes") - - @abc.abstractproperty - def output_types(self): - """Returns the type of each component of this optional. + def value_structure(self): + """The structure of the components of this optional. Returns: - A nested structure of `tf.DType` objects corresponding to each component - of this optional. + A `Structure` object representing the structure of the components of this + optional. """ - raise NotImplementedError("Optional.output_types") + raise NotImplementedError("Optional.value_structure") @staticmethod def from_value(value): @@ -108,48 +84,30 @@ class Optional(object): Returns: An `Optional` that wraps `value`. """ - # TODO(b/110122868): Consolidate this destructuring logic with the - # similar code in `Dataset.from_tensors()`. with ops.name_scope("optional") as scope: with ops.name_scope("value"): - value = nest.pack_sequence_as(value, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( - t, name="component_%d" % i) - for i, t in enumerate(nest.flatten(value)) - ]) - - encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value)) - output_classes = sparse.get_classes(value) - output_shapes = nest.pack_sequence_as( - value, [t.get_shape() for t in nest.flatten(value)]) - output_types = nest.pack_sequence_as( - value, [t.dtype for t in nest.flatten(value)]) + value_structure = structure.Structure.from_value(value) + encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access return _OptionalImpl( gen_dataset_ops.optional_from_value(encoded_value, name=scope), - output_shapes, output_types, output_classes) + value_structure) @staticmethod - def none_from_structure(output_shapes, output_types, output_classes): + def none_from_structure(value_structure): """Returns an `Optional` that has no value. - NOTE: This method takes arguments that define the structure of the value + NOTE: This method takes an argument that defines the structure of the value that would be contained in the returned `Optional` if it had a value. Args: - output_shapes: A nested structure of `tf.TensorShape` objects - corresponding to each component of this optional. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of this optional. - output_classes: A nested structure of Python `type` objects corresponding - to each component of this optional. + value_structure: A `Structure` object representing the structure of the + components of this optional. Returns: An `Optional` that has no value. """ - return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes, - output_types, output_classes) + return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure) class _OptionalImpl(Optional): @@ -159,20 +117,9 @@ class _OptionalImpl(Optional): `Optional.__init__()` in the public API. """ - def __init__(self, variant_tensor, output_shapes, output_types, - output_classes): - # TODO(b/110122868): Consolidate the structure validation logic with the - # similar logic in `Iterator.from_structure()` and - # `Dataset.from_generator()`. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - nest.assert_same_structure(output_types, output_shapes) - nest.assert_same_structure(output_types, output_classes) + def __init__(self, variant_tensor, value_structure): self._variant_tensor = variant_tensor - self._output_shapes = output_shapes - self._output_types = output_types - self._output_classes = output_classes + self._value_structure = value_structure def has_value(self, name=None): return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name) @@ -182,28 +129,55 @@ class _OptionalImpl(Optional): # in `Iterator.get_next()` and `StructuredFunctionWrapper`. with ops.name_scope(name, "OptionalGetValue", [self._variant_tensor]) as scope: - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as( - self._output_types, - gen_dataset_ops.optional_get_value( - self._variant_tensor, - name=scope, - output_types=nest.flatten( - sparse.as_dense_types(self._output_types, - self._output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self._output_shapes, - self._output_classes)))), - self._output_types, self._output_shapes, self._output_classes) + # pylint: disable=protected-access + return self._value_structure._from_tensor_list( + gen_dataset_ops.optional_get_value( + self._variant_tensor, + name=scope, + output_types=self._value_structure._flat_types, + output_shapes=self._value_structure._flat_shapes)) @property - def output_classes(self): - return self._output_classes + def value_structure(self): + return self._value_structure + + +class OptionalStructure(structure.Structure): + """Represents an optional potentially containing a structured value.""" + + def __init__(self, value_structure): + self._value_structure = value_structure @property - def output_shapes(self): - return self._output_shapes + def _flat_shapes(self): + return [tensor_shape.scalar()] @property - def output_types(self): - return self._output_types + def _flat_types(self): + return [dtypes.variant] + + def is_compatible_with(self, other): + # pylint: disable=protected-access + return (isinstance(other, OptionalStructure) and + self._value_structure.is_compatible_with(other._value_structure)) + + def _to_tensor_list(self, value): + return [value._variant_tensor] # pylint: disable=protected-access + + def _from_tensor_list(self, flat_value): + if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or + not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "OptionalStructure corresponds to a single tf.variant scalar.") + # pylint: disable=protected-access + return _OptionalImpl(flat_value[0], self._value_structure) + + @staticmethod + def from_value(value): + return OptionalStructure(value.value_structure) + + +# pylint: disable=protected-access +structure.Structure._register_custom_converter(Optional, + OptionalStructure.from_value) +# pylint: enable=protected-access diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index c5764b8dfe..a90ca258c0 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -28,6 +28,9 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import sparse_ops +_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {} + + class Structure(object): """Represents structural information, such as type and shape, about a value. @@ -64,12 +67,10 @@ class Structure(object): raise NotImplementedError("Structure._flat_shapes") @abc.abstractmethod - def is_compatible_with(self, value): - """Returns `True` if `value` is compatible with this structure. + def is_compatible_with(self, other): + """Returns `True` if `other` is compatible with this structure. - A value `value` is compatible with a structure `s` if - `Structure.from_value(value)` would return a structure `t` that is a - "subtype" of `s`. A structure `t` is a "subtype" of `s` if: + A structure `t` is a "subtype" of `s` if: * `s` and `t` are instances of the same `Structure` subclass. * The nested structures (if any) of `s` and `t` are the same, according to @@ -83,10 +84,10 @@ class Structure(object): `tf.TensorShape.is_compatible_with`. Args: - value: A potentially structured value. + other: A `Structure`. Returns: - `True` if `value` matches this structure, otherwise `False`. + `True` if `other` is a subtype of this structure, otherwise `False`. """ raise NotImplementedError("Structure.is_compatible_with()") @@ -98,7 +99,7 @@ class Structure(object): `self._flat_types` to represent structured values in lower level APIs (such as plain TensorFlow operations) that do not understand structure. - Requires: `self.is_compatible_with(value)`. + Requires: `self.is_compatible_with(Structure.from_value(value))`. Args: value: A value with compatible structure. @@ -137,9 +138,8 @@ class Structure(object): TypeError: If a structure cannot be built for `value`, because its type or one of its component types is not supported. """ - - # TODO(b/110122868): Add support for custom types, Dataset, and Optional - # to this method. + # TODO(b/110122868): Add support for custom types and Dataset to this + # method. if isinstance( value, (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)): @@ -147,12 +147,76 @@ class Structure(object): elif isinstance(value, (tuple, dict)): return NestedStructure.from_value(value) else: + for converter_type, converter_fn in ( + _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()): + if isinstance(value, converter_type): + return converter_fn(value) try: tensor = ops.convert_to_tensor(value) except (ValueError, TypeError): raise TypeError("Could not build a structure for %r" % value) return TensorStructure.from_value(tensor) + @staticmethod + def _from_legacy_structure(output_types, output_shapes, output_classes): + """Returns a `Structure` that represents the given legacy structure. + + This method provides a way to convert from the existing `Dataset` and + `Iterator` structure-related properties to a `Structure` object. + + TODO(b/110122868): Remove this method once `Structure` is used throughout + `tf.data`. + + Args: + output_types: A nested structure of `tf.DType` objects corresponding to + each component of a structured value. + output_shapes: A nested structure of `tf.TensorShape` objects + corresponding to each component a structured value. + output_classes: A nested structure of Python `type` objects corresponding + to each component of a structured value. + + Returns: + A `Structure`. + + Raises: + TypeError: If a structure cannot be built the arguments, because one of + the component classes in `output_classes` is not supported. + """ + flat_types = nest.flatten(output_types) + flat_shapes = nest.flatten(output_shapes) + flat_classes = nest.flatten(output_classes) + flat_ret = [] + for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes, + flat_classes): + if issubclass(flat_class, sparse_tensor_lib.SparseTensor): + flat_ret.append(SparseTensorStructure(flat_type, flat_shape)) + elif issubclass(flat_class, ops.Tensor): + flat_ret.append(TensorStructure(flat_type, flat_shape)) + else: + # NOTE(mrry): Since legacy structures produced by iterators only + # comprise Tensors, SparseTensors, and nests, we do not need to support + # all structure types here. + raise TypeError( + "Could not build a structure for output class %r" % flat_type) + + ret = nest.pack_sequence_as(output_classes, flat_ret) + if isinstance(ret, Structure): + return ret + else: + return NestedStructure(ret) + + @staticmethod + def _register_custom_converter(type_object, converter_fn): + """Registers `converter_fn` for converting values of the given type. + + Args: + type_object: A Python `type` object representing the type of values + accepted by `converter_fn`. + converter_fn: A function that takes one argument (an instance of the + type represented by `type_object`) and returns a `Structure`. + """ + _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn + # NOTE(mrry): The following classes make extensive use of non-public methods of # their base class, so we disable the protected-access lint warning once here. @@ -179,16 +243,21 @@ class NestedStructure(Structure): def _flat_types(self): return self._flat_types_list - def is_compatible_with(self, value): + def is_compatible_with(self, other): + if not isinstance(other, NestedStructure): + return False try: - nest.assert_shallow_structure(self._nested_structure, value) + # pylint: disable=protected-access + nest.assert_same_structure(self._nested_structure, + other._nested_structure) except (ValueError, TypeError): return False return all( - s.is_compatible_with(v) for s, v in zip( + substructure.is_compatible_with(other_substructure) + for substructure, other_substructure in zip( nest.flatten(self._nested_structure), - nest.flatten_up_to(self._nested_structure, value))) + nest.flatten(other._nested_structure))) def _to_tensor_list(self, value): ret = [] @@ -201,7 +270,7 @@ class NestedStructure(Structure): for sub_value, structure in zip(flat_value, nest.flatten(self._nested_structure)): - if not structure.is_compatible_with(sub_value): + if not structure.is_compatible_with(Structure.from_value(sub_value)): raise ValueError("Component value %r is not compatible with the nested " "structure %r." % (sub_value, structure)) ret.extend(structure._to_tensor_list(sub_value)) @@ -242,17 +311,13 @@ class TensorStructure(Structure): def _flat_types(self): return [self._dtype] - def is_compatible_with(self, value): - try: - value = ops.convert_to_tensor(value, dtype=self._dtype) - except (ValueError, TypeError): - return False - - return (self._dtype.is_compatible_with(value.dtype) and - self._shape.is_compatible_with(value.shape)) + def is_compatible_with(self, other): + return (isinstance(other, TensorStructure) and + self._dtype.is_compatible_with(other._dtype) and + self._shape.is_compatible_with(other._shape)) def _to_tensor_list(self, value): - if not self.is_compatible_with(value): + if not self.is_compatible_with(Structure.from_value(value)): raise ValueError("Value %r is not convertible to a tensor with dtype %s " "and shape %s." % (value, self._dtype, self._shape)) return [value] @@ -260,7 +325,7 @@ class TensorStructure(Structure): def _from_tensor_list(self, flat_value): if len(flat_value) != 1: raise ValueError("TensorStructure corresponds to a single tf.Tensor.") - if not self.is_compatible_with(flat_value[0]): + if not self.is_compatible_with(Structure.from_value(flat_value[0])): raise ValueError("Cannot convert %r to a tensor with dtype %s and shape " "%s." % (flat_value[0], self._dtype, self._shape)) return flat_value[0] @@ -285,16 +350,10 @@ class SparseTensorStructure(Structure): def _flat_types(self): return [dtypes.variant] - def is_compatible_with(self, value): - try: - value = sparse_tensor_lib.SparseTensor.from_value(value) - except TypeError: - return False - return (isinstance(value, (sparse_tensor_lib.SparseTensor, - sparse_tensor_lib.SparseTensorValue)) and - self._dtype.is_compatible_with(value.dtype) and - self._dense_shape.is_compatible_with( - tensor_util.constant_value_as_shape(value.dense_shape))) + def is_compatible_with(self, other): + return (isinstance(other, SparseTensorStructure) and + self._dtype.is_compatible_with(other._dtype) and + self._dense_shape.is_compatible_with(other._dense_shape)) def _to_tensor_list(self, value): return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)] diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index d0c7df67ae..2982763181 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -25,7 +25,9 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import structure from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -106,13 +108,17 @@ class StructureTest(test.TestCase, parameterized.TestCase): indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) - def testIsCompatibleWith(self, original_value, compatible_values, - incompatible_values): + def testIsCompatibleWithStructure(self, original_value, compatible_values, + incompatible_values): s = structure.Structure.from_value(original_value) for compatible_value in compatible_values: - self.assertTrue(s.is_compatible_with(compatible_value)) + self.assertTrue( + s.is_compatible_with( + structure.Structure.from_value(compatible_value))) for incompatible_value in incompatible_values: - self.assertFalse(s.is_compatible_with(incompatible_value)) + self.assertFalse( + s.is_compatible_with( + structure.Structure.from_value(incompatible_value))) # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they # will be executed before the (eager- or graph-mode) test environment has been @@ -322,6 +328,28 @@ class StructureTest(test.TestCase, parameterized.TestCase): ValueError, "Expected 3 flat values in NestedStructure but got 2."): s_2._from_tensor_list(flat_s_1) + @parameterized.named_parameters( + ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor, + structure.TensorStructure(dtypes.float32, [])), + ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2), + sparse_tensor.SparseTensor, + structure.SparseTensorStructure(dtypes.int32, [2, 2])), + ("Nest", + {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)}, + {"a": tensor_shape.scalar(), + "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())}, + {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)}, + structure.NestedStructure({ + "a": structure.TensorStructure(dtypes.float32, []), + "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), + structure.TensorStructure(dtypes.string, []))})), + ) + def testFromLegacyStructure(self, output_types, output_shapes, output_classes, + expected_structure): + actual_structure = structure.Structure._from_legacy_structure( + output_types, output_shapes, output_classes) + self.assertTrue(expected_structure.is_compatible_with(actual_structure)) + self.assertTrue(actual_structure.is_compatible_with(expected_structure)) if __name__ == "__main__": test.main() |