aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 16:08:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 16:08:25 -0700
commite4e036fef21b933171ce382c36f6c730f8322219 (patch)
treefc27ab76eb1e43f2f84118625b2e39682fc87b85 /tensorflow/python/data
parentadb742eba146478c3cee86d7b366e3faf121f6bd (diff)
parentadb5d74f52917d00e9a779a74f0e0a4e5ca22ca4 (diff)
Merge pull request #22170 from Smokrow:patch-1
PiperOrigin-RevId: 214058098
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 93b3a7b93b..7c20c049f5 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1007,8 +1007,25 @@ class Dataset(object):
return ParallelMapDataset(self, map_func, num_parallel_calls)
def flat_map(self, map_func):
- """Maps `map_func` across this dataset and flattens the result.
+ """Maps `map_func` across this dataset and flattens the result.
+
+ Use `flat_map` if you want to make sure that the order of your dataset
+ stays the same. For example, to flatten a dataset of batches into a
+ dataset of their elements:
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset. '[...]' represents a tensor.
+ a = {[1,2,3,4,5], [6,7,8,9], [10]}
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+ {[1,2,3,4,5,6,7,8,9,10]}
+ ```
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
+ `tf.data.Dataset.interleave(cycle_length=1)`
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1043,7 +1060,7 @@ class Dataset(object):
elements are produced. `cycle_length` controls the number of input elements
that are processed concurrently. If you set `cycle_length` to 1, this
transformation will handle one input element at a time, and will produce
- identical results = to `tf.data.Dataset.flat_map`. In general,
+ identical results to `tf.data.Dataset.flat_map`. In general,
this transformation will apply `map_func` to `cycle_length` input elements,
open iterators on the returned `Dataset` objects, and cycle through them
producing `block_length` consecutive elements from each iterator, and