diff options
Diffstat (limited to 'tensorflow/python/data/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 56 |
1 files changed, 52 insertions, 4 deletions
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6205ee392e..c985e00dd1 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1019,7 +1019,11 @@ class Dataset(object): """ return FlatMapDataset(self, map_func) - def interleave(self, map_func, cycle_length, block_length=1): + def interleave(self, + map_func, + cycle_length, + block_length=1, + num_parallel_calls=None): """Maps `map_func` across this dataset, and interleaves the results. For example, you can use `Dataset.interleave()` to process many input files @@ -1082,11 +1086,19 @@ class Dataset(object): processed concurrently. block_length: The number of consecutive elements to produce from each input element before cycling to another input element. + num_parallel_calls: (Optional.) If specified, the implementation creates + a threadpool, which is used to fetch inputs from cycle elements + asynchronously and in parallel. The default behavior is to fetch inputs + from cycle elements synchronously with no parallelism. Returns: Dataset: A `Dataset`. """ - return InterleaveDataset(self, map_func, cycle_length, block_length) + if num_parallel_calls is None: + return InterleaveDataset(self, map_func, cycle_length, block_length) + else: + return ParallelInterleaveDataset(self, map_func, cycle_length, + block_length, num_parallel_calls) def filter(self, predicate): """Filters this dataset according to `predicate`. @@ -2245,9 +2257,14 @@ class MapDataset(Dataset): class ParallelMapDataset(MapDataset): """A `Dataset` that maps a function over elements in its input in parallel.""" - def __init__(self, input_dataset, map_func, num_parallel_calls): + def __init__(self, + input_dataset, + map_func, + num_parallel_calls, + use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" - super(ParallelMapDataset, self).__init__(input_dataset, map_func) + super(ParallelMapDataset, self).__init__(input_dataset, map_func, + use_inter_op_parallelism) self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls") @@ -2260,6 +2277,7 @@ class ParallelMapDataset(MapDataset): self._map_func.captured_inputs, f=self._map_func, num_parallel_calls=self._num_parallel_calls, + use_inter_op_parallelism=self._use_inter_op_parallelism, **flat_structure(self)) # pylint: enable=protected-access @@ -2330,6 +2348,36 @@ class InterleaveDataset(FlatMapDataset): return "Dataset.interleave()" +class ParallelInterleaveDataset(FlatMapDataset): + """A `Dataset` that maps a function over its input and interleaves the result. + + """ + + def __init__(self, input_dataset, map_func, cycle_length, block_length, + num_parallel_calls): + """See `Dataset.interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func) + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") + self._num_parallel_calls = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + + def _as_variant_tensor(self): + return gen_dataset_ops.parallel_interleave_dataset_v2( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._map_func.captured_inputs, # pylint: disable=protected-access + self._cycle_length, + self._block_length, + self._num_parallel_calls, + f=self._map_func, # pylint: disable=protected-access + **flat_structure(self)) + + def _transformation_name(self): + return "Dataset.interleave()" + + class FilterDataset(Dataset): """A `Dataset` that filters its input according to a predicate function.""" |