diff options
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 147 |
1 files changed, 135 insertions, 12 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 7cb6627615..88de4b588c 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -24,6 +24,7 @@ import warnings import numpy as np import six +from tensorflow.python.compat import compat from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed @@ -107,8 +108,12 @@ class Dataset(object): "execution is enabled.") if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator( - container="", shared_name=shared_name, **flat_structure(self)) + if compat.forward_compatible(2018, 8, 3): + iterator_resource = gen_dataset_ops.iterator_v2( + container="", shared_name=shared_name, **flat_structure(self)) + else: + iterator_resource = gen_dataset_ops.iterator( + container="", shared_name=shared_name, **flat_structure(self)) with ops.colocate_with(iterator_resource): initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), iterator_resource) @@ -888,7 +893,83 @@ class Dataset(object): drop_remainder) def map(self, map_func, num_parallel_calls=None): - """Maps `map_func` across this dataset. + """Maps `map_func` across the elements of this dataset. + + This transformation applies `map_func` to each element of this dataset, and + returns a new dataset containing the transformed elements, in the same + order as they appeared in the input. + + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3, 4, 5 } + + a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 } + ``` + + The input signature of `map_func` is determined by the structure of each + element in this dataset. For example: + + ```python + # Each element is a `tf.Tensor` object. 
+ a = { 1, 2, 3, 4, 5 } + # `map_func` takes a single argument of type `tf.Tensor` with the same + # shape and dtype. + result = a.map(lambda x: ...) + + # Each element is a tuple containing two `tf.Tensor` objects. + b = { (1, "foo"), (2, "bar"), (3, "baz") } + # `map_func` takes two arguments of type `tf.Tensor`. + result = b.map(lambda x_int, y_str: ...) + + # Each element is a dictionary mapping strings to `tf.Tensor` objects. + c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} } + # `map_func` takes a single argument of type `dict` with the same keys as + # the elements. + result = c.map(lambda d: ...) + ``` + + The value or values returned by `map_func` determine the structure of each + element in the returned dataset. + + ```python + # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`. + def f(...): + return tf.constant(37.0) + result = dataset.map(f) + result.output_classes == tf.Tensor + result.output_types == tf.float32 + result.output_shapes == [] # scalar + + # `map_func` returns two `tf.Tensor` objects. + def g(...): + return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) + result = dataset.map(g) + result.output_classes == (tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string) + result.output_shapes == ([], [3]) + + # Python primitives, lists, and NumPy arrays are implicitly converted to + # `tf.Tensor`. + def h(...): + return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0], dtype=np.float64) + result = dataset.map(h) + result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string, tf.float64) + result.output_shapes == ([], [3], [2]) + + # `map_func` can return nested structures. 
+ def i(...): + return {"a": 37.0, "b": [42, 16]}, "foo" + result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor) + result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string) + result.output_shapes == ({"a": [], "b": [2]}, []) + ``` + + In addition to `tf.Tensor` objects, `map_func` can accept as arguments and + return `tf.SparseTensor` objects. Args: map_func: A function mapping a nested structure of tensors (having @@ -1168,10 +1249,29 @@ class _NestedDatasetComponent(object): custom component types. """ - def __init__(self, dataset): - self._output_classes = dataset.output_classes - self._output_shapes = dataset.output_shapes - self._output_types = dataset.output_types + def __init__(self, + dataset=None, + output_shapes=None, + output_types=None, + output_classes=None): + if dataset is None: + if (output_classes is None or output_shapes is None or + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = output_classes + self._output_shapes = output_shapes + self._output_types = output_types + else: + if not (output_classes is None and output_shapes is None and + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = dataset.output_classes + self._output_shapes = dataset.output_shapes + self._output_types = dataset.output_types @property def output_classes(self): @@ -1330,7 +1430,11 @@ class StructuredFunctionWrapper(object): flat_shapes.append(component) flat_types.append(component) else: - t = ops.convert_to_tensor(t) + try: + t = ops.convert_to_tensor(t) + except (ValueError, TypeError): + raise TypeError("Unsupported return value from function passed to " + "%s: %s." 
% (transformation_name, t)) flat_ret.append(t) flat_classes.append(ops.Tensor) flat_shapes.append(t.get_shape()) @@ -1406,11 +1510,30 @@ def flat_structure(dataset): A dictionary of keyword arguments that can be passed to many Dataset op constructors. """ + output_classes = [] + output_shapes = [] + output_types = [] + for output_class, output_shape, output_type in zip( + nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_types)): + if isinstance(output_class, _NestedDatasetComponent): + output_classes.append(output_class.output_classes) + output_shapes.append(output_shape.output_shapes) + output_types.append(output_type.output_types) + else: + output_classes.append(output_class) + output_shapes.append(output_shape) + output_types.append(output_type) + + output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes) + output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes) + output_types = nest.pack_sequence_as(dataset.output_types, output_types) + return { - "output_shapes": nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)), - "output_types": nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), + "output_shapes": + nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)), + "output_types": + nest.flatten(sparse.as_dense_types(output_types, output_classes)), } |