about summary refs log tree commit diff homepage
path: root/tensorflow/python/data/ops/dataset_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py  56
1 file changed, 52 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6205ee392e..c985e00dd1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1019,7 +1019,11 @@ class Dataset(object):
"""
return FlatMapDataset(self, map_func)
- def interleave(self, map_func, cycle_length, block_length=1):
+ def interleave(self,
+ map_func,
+ cycle_length,
+ block_length=1,
+ num_parallel_calls=None):
"""Maps `map_func` across this dataset, and interleaves the results.
For example, you can use `Dataset.interleave()` to process many input files
@@ -1082,11 +1086,19 @@ class Dataset(object):
processed concurrently.
block_length: The number of consecutive elements to produce from each
input element before cycling to another input element.
+ num_parallel_calls: (Optional.) If specified, the implementation creates
+ a threadpool, which is used to fetch inputs from cycle elements
+ asynchronously and in parallel. The default behavior is to fetch inputs
+ from cycle elements synchronously with no parallelism.
Returns:
Dataset: A `Dataset`.
"""
- return InterleaveDataset(self, map_func, cycle_length, block_length)
+ if num_parallel_calls is None:
+ return InterleaveDataset(self, map_func, cycle_length, block_length)
+ else:
+ return ParallelInterleaveDataset(self, map_func, cycle_length,
+ block_length, num_parallel_calls)
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
@@ -2245,9 +2257,14 @@ class MapDataset(Dataset):
class ParallelMapDataset(MapDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
- def __init__(self, input_dataset, map_func, num_parallel_calls):
+ def __init__(self,
+ input_dataset,
+ map_func,
+ num_parallel_calls,
+ use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(ParallelMapDataset, self).__init__(input_dataset, map_func)
+ super(ParallelMapDataset, self).__init__(input_dataset, map_func,
+ use_inter_op_parallelism)
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
@@ -2260,6 +2277,7 @@ class ParallelMapDataset(MapDataset):
self._map_func.captured_inputs,
f=self._map_func,
num_parallel_calls=self._num_parallel_calls,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
# pylint: enable=protected-access
@@ -2330,6 +2348,36 @@ class InterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
+class ParallelInterleaveDataset(FlatMapDataset):
+  """A `Dataset` that maps a function over its input and interleaves the result.
+  Inputs are fetched from cycle elements in parallel via a threadpool of size `num_parallel_calls`.
+  """
+
+  def __init__(self, input_dataset, map_func, cycle_length, block_length,
+               num_parallel_calls):
+    """See `Dataset.interleave()` for details."""
+    super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
+    self._cycle_length = ops.convert_to_tensor(
+        cycle_length, dtype=dtypes.int64, name="cycle_length")
+    self._block_length = ops.convert_to_tensor(
+        block_length, dtype=dtypes.int64, name="block_length")
+    self._num_parallel_calls = ops.convert_to_tensor(
+        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.parallel_interleave_dataset_v2(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        self._map_func.captured_inputs,  # pylint: disable=protected-access
+        self._cycle_length,
+        self._block_length,
+        self._num_parallel_calls,
+        f=self._map_func,  # pylint: disable=protected-access
+        **flat_structure(self))
+
+  def _transformation_name(self):
+    return "Dataset.interleave()"
+
+
class FilterDataset(Dataset):
"""A `Dataset` that filters its input according to a predicate function."""