diff options
author | 2018-09-06 16:09:15 -0700 | |
---|---|---|
committer | 2018-09-06 16:17:43 -0700 | |
commit | 9a6ab2af59f3b21ffa2b74093ccc9af4edaf7f98 (patch) | |
tree | 748882485661e750cf19f1d9ca182a590bbb7c8b /tensorflow/contrib/data | |
parent | 33d2a0e7064cd14540121e38457d4a81aa57a650 (diff) |
[tf.data] Adding support for `num_parallel_calls` to `tf.data.Dataset.interleave`.
Unlike the `tf.data.contrib.parallel_interleave` whose parallelism is tied to the `cycle_length` argument, the newly introduced `num_parallel_calls` argument of `tf.data.Dataset.interleave` is decoupled from the `cycle_length` argument and identifies the degree of parallelism to use for fetching output elements.
PiperOrigin-RevId: 211886816
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/serialization/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py | 45 |
2 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 4881f63ab9..aa89674c6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -210,6 +210,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py index ac3892fe81..243f6405a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base @@ -27,42 +28,38 @@ from tensorflow.python.platform import test class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): - def _build_iterator_graph(self, input_values, cycle_length, block_length): + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): repeat_count = 2 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( repeat_count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) + cycle_length, block_length, num_parallel_calls) - def testSerializationCore(self): + @parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): input_values = np.array([4, 5, 6], dtype=np.int64) num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 # pylint: disable=g-long-lambda self.run_core_tests( lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), + input_values, cycle_length, block_length, num_parallel_calls), lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), + input_values, cycle_length * 2, block_length, num_parallel_calls), num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) # pylint: enable=g-long-lambda def testSparseCore(self): @@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) -if __name__ == '__main__': +if __name__ == "__main__": test.main() |