aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/ops/dataset_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py147
1 files changed, 135 insertions, 12 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7cb6627615..88de4b588c 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -24,6 +24,7 @@ import warnings
import numpy as np
import six
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
@@ -107,8 +108,12 @@ class Dataset(object):
"execution is enabled.")
if shared_name is None:
shared_name = ""
- iterator_resource = gen_dataset_ops.iterator(
- container="", shared_name=shared_name, **flat_structure(self))
+ if compat.forward_compatible(2018, 8, 3):
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="", shared_name=shared_name, **flat_structure(self))
+ else:
+ iterator_resource = gen_dataset_ops.iterator(
+ container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
iterator_resource)
@@ -888,7 +893,83 @@ class Dataset(object):
drop_remainder)
def map(self, map_func, num_parallel_calls=None):
- """Maps `map_func` across this dataset.
+ """Maps `map_func` across the elements of this dataset.
+
+ This transformation applies `map_func` to each element of this dataset, and
+ returns a new dataset containing the transformed elements, in the same
+ order as they appeared in the input.
+
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3, 4, 5 }
+
+ a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 }
+ ```
+
+ The input signature of `map_func` is determined by the structure of each
+ element in this dataset. For example:
+
+ ```python
+ # Each element is a `tf.Tensor` object.
+ a = { 1, 2, 3, 4, 5 }
+ # `map_func` takes a single argument of type `tf.Tensor` with the same
+ # shape and dtype.
+ result = a.map(lambda x: ...)
+
+ # Each element is a tuple containing two `tf.Tensor` objects.
+ b = { (1, "foo"), (2, "bar"), (3, "baz") }
+ # `map_func` takes two arguments of type `tf.Tensor`.
+ result = b.map(lambda x_int, y_str: ...)
+
+ # Each element is a dictionary mapping strings to `tf.Tensor` objects.
+ c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
+ # `map_func` takes a single argument of type `dict` with the same keys as
+ # the elements.
+ result = c.map(lambda d: ...)
+ ```
+
+ The value or values returned by `map_func` determine the structure of each
+ element in the returned dataset.
+
+ ```python
+ # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
+ def f(...):
+ return tf.constant(37.0)
+ result = dataset.map(f)
+ result.output_classes == tf.Tensor
+ result.output_types == tf.float32
+ result.output_shapes == [] # scalar
+
+ # `map_func` returns two `tf.Tensor` objects.
+ def g(...):
+ return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
+ result = dataset.map(g)
+ result.output_classes == (tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string)
+ result.output_shapes == ([], [3])
+
+ # Python primitives, lists, and NumPy arrays are implicitly converted to
+ # `tf.Tensor`.
+ def h(...):
+ return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0], dtype=np.float64)
+ result = dataset.map(h)
+ result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string, tf.float64)
+ result.output_shapes == ([], [3], [2])
+
+ # `map_func` can return nested structures.
+ def i(...):
+ return {"a": 37.0, "b": [42, 16]}, "foo"
+ result = dataset.map(i)
+ result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
+ result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
+ result.output_shapes == ({"a": [], "b": [2]}, [])
+ ```
+
+ In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
+ return `tf.SparseTensor` objects.
+
Args:
map_func: A function mapping a nested structure of tensors (having
@@ -1168,10 +1249,29 @@ class _NestedDatasetComponent(object):
custom component types.
"""
- def __init__(self, dataset):
- self._output_classes = dataset.output_classes
- self._output_shapes = dataset.output_shapes
- self._output_types = dataset.output_types
+ def __init__(self,
+ dataset=None,
+ output_shapes=None,
+ output_types=None,
+ output_classes=None):
+ if dataset is None:
+ if (output_classes is None or output_shapes is None or
+ output_types is None):
+ raise ValueError(
+ "Either `dataset`, or all of `output_classes`, "
+ "`output_shapes`, and `output_types` must be specified.")
+ self._output_classes = output_classes
+ self._output_shapes = output_shapes
+ self._output_types = output_types
+ else:
+ if not (output_classes is None and output_shapes is None and
+ output_types is None):
+ raise ValueError(
+ "Either `dataset`, or all of `output_classes`, "
+ "`output_shapes`, and `output_types` must be specified.")
+ self._output_classes = dataset.output_classes
+ self._output_shapes = dataset.output_shapes
+ self._output_types = dataset.output_types
@property
def output_classes(self):
@@ -1330,7 +1430,11 @@ class StructuredFunctionWrapper(object):
flat_shapes.append(component)
flat_types.append(component)
else:
- t = ops.convert_to_tensor(t)
+ try:
+ t = ops.convert_to_tensor(t)
+ except (ValueError, TypeError):
+ raise TypeError("Unsupported return value from function passed to "
+ "%s: %s." % (transformation_name, t))
flat_ret.append(t)
flat_classes.append(ops.Tensor)
flat_shapes.append(t.get_shape())
@@ -1406,11 +1510,30 @@ def flat_structure(dataset):
A dictionary of keyword arguments that can be passed to many Dataset op
constructors.
"""
+ output_classes = []
+ output_shapes = []
+ output_types = []
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_types)):
+ if isinstance(output_class, _NestedDatasetComponent):
+ output_classes.append(output_class.output_classes)
+ output_shapes.append(output_shape.output_shapes)
+ output_types.append(output_type.output_types)
+ else:
+ output_classes.append(output_class)
+ output_shapes.append(output_shape)
+ output_types.append(output_type)
+
+ output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes)
+ output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes)
+ output_types = nest.pack_sequence_as(dataset.output_types, output_types)
+
return {
- "output_shapes": nest.flatten(sparse.as_dense_shapes(
- dataset.output_shapes, dataset.output_classes)),
- "output_types": nest.flatten(sparse.as_dense_types(
- dataset.output_types, dataset.output_classes)),
+ "output_shapes":
+ nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)),
+ "output_types":
+ nest.flatten(sparse.as_dense_types(output_types, output_classes)),
}