aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py167
1 files changed, 78 insertions, 89 deletions
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index 7dbf7268d7..a35cee594a 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -19,8 +19,10 @@ from __future__ import print_function
import itertools
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase):
+class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
@@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase):
expected_elements, self._interleave(input_lists, 7, 2)):
self.assertEqual(expected, produced)
- def testInterleaveDataset(self):
- input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- block_length = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_count = 2
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_values)
- .repeat(repeat_count)
- .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
+ @parameterized.named_parameters(
+ ("1", np.int64([4, 5, 6]), 1, 3, None),
+ ("2", np.int64([4, 5, 6]), 1, 3, 1),
+ ("3", np.int64([4, 5, 6]), 2, 1, None),
+ ("4", np.int64([4, 5, 6]), 2, 1, 1),
+ ("5", np.int64([4, 5, 6]), 2, 1, 2),
+ ("6", np.int64([4, 5, 6]), 2, 3, None),
+ ("7", np.int64([4, 5, 6]), 2, 3, 1),
+ ("8", np.int64([4, 5, 6]), 2, 3, 2),
+ ("9", np.int64([4, 5, 6]), 7, 2, None),
+ ("10", np.int64([4, 5, 6]), 7, 2, 1),
+ ("11", np.int64([4, 5, 6]), 7, 2, 3),
+ ("12", np.int64([4, 5, 6]), 7, 2, 5),
+ ("13", np.int64([4, 5, 6]), 7, 2, 7),
+ ("14", np.int64([]), 2, 3, None),
+ ("15", np.int64([0, 0, 0]), 2, 3, None),
+ ("16", np.int64([4, 0, 6]), 2, 3, None),
+ ("17", np.int64([4, 0, 6]), 2, 3, 1),
+ ("18", np.int64([4, 0, 6]), 2, 3, 2),
+ )
+ def testInterleaveDataset(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ count = 2
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ def repeat(values, count):
+ result = []
+ for value in values:
+ result.append([value] * value)
+ return result * count
with self.test_session() as sess:
- # Cycle length 1 acts like `Dataset.flat_map()`.
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 1, block_length: 3})
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
- self.assertEqual(expected_element, sess.run(next_element))
-
- # Cycle length > 1.
- # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
- # 6, 5, 6, 5, 6, 5, 6, 5]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 1})
for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > 1 and block length > 1.
- # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
- # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > len(input_values) * repeat_count.
- # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
- # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 7, block_length: 2})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Empty input.
- sess.run(init_op, feed_dict={input_values: [],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ repeat(input_values, count), cycle_length, block_length):
+ self.assertEqual(expected_element, sess.run(get_next))
+
+ for _ in range(2):
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
+ ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
+ ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
+ ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
+ ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
+ ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
+ ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
+ ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
+ ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
+ ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
+ ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
+ ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
+ ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
+ )
+ def testInterleaveErrorDataset(self,
+ input_values,
+ cycle_length,
+ block_length,
+ num_parallel_calls):
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).interleave(
+ dataset_ops.Dataset.from_tensors, cycle_length, block_length,
+ num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
- # Non-empty input leading to empty output.
- sess.run(init_op, feed_dict={input_values: [0, 0, 0],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Mixture of non-empty and empty interleaved datasets.
- sess.run(init_op, feed_dict={input_values: [4, 0, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
+ with self.test_session() as sess:
+ for value in input_values:
+ if np.isnan(value):
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+ else:
+ self.assertEqual(value, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ sess.run(get_next)
def testSparse(self):
@@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyInput(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([])
- .repeat(None)
- .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
if __name__ == "__main__":
test.main()