aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-06 16:09:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 16:17:43 -0700
commit9a6ab2af59f3b21ffa2b74093ccc9af4edaf7f98 (patch)
tree748882485661e750cf19f1d9ca182a590bbb7c8b /tensorflow/contrib/data
parent33d2a0e7064cd14540121e38457d4a81aa57a650 (diff)
[tf.data] Adding support for `num_parallel_calls` to `tf.data.Dataset.interleave`.
Unlike the `tf.data.contrib.parallel_interleave` whose parallelism is tied to the `cycle_length` argument, the newly introduced `num_parallel_calls` argument of `tf.data.Dataset.interleave` is decoupled from the `cycle_length` argument and identifies the degree of parallelism to use for fetching output elements. PiperOrigin-RevId: 211886816
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py45
2 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 4881f63ab9..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()