diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py | 167 |
1 files changed, 78 insertions, 89 deletions
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index 7dbf7268d7..a35cee594a 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -19,8 +19,10 @@ from __future__ import print_function import itertools +from absl.testing import parameterized +import numpy as np + from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase): +class InterleaveDatasetTest(test.TestCase, parameterized.TestCase): def _interleave(self, lists, cycle_length, block_length): num_open = 0 @@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase): expected_elements, self._interleave(input_lists, 7, 2)): self.assertEqual(expected, produced) - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - next_element = iterator.get_next() + @parameterized.named_parameters( + ("1", np.int64([4, 5, 6]), 1, 3, None), + ("2", np.int64([4, 5, 6]), 1, 3, 1), + ("3", np.int64([4, 5, 6]), 2, 1, None), + ("4", np.int64([4, 5, 6]), 2, 1, 1), + ("5", np.int64([4, 5, 6]), 2, 1, 2), + ("6", np.int64([4, 5, 6]), 2, 3, None), + ("7", np.int64([4, 5, 6]), 2, 3, 1), + ("8", np.int64([4, 5, 6]), 2, 3, 2), + ("9", np.int64([4, 5, 6]), 7, 2, None), + ("10", np.int64([4, 5, 6]), 7, 2, 1), + ("11", np.int64([4, 5, 6]), 7, 2, 3), + ("12", np.int64([4, 5, 6]), 7, 2, 5), + ("13", np.int64([4, 5, 6]), 7, 2, 7), + ("14", np.int64([]), 2, 3, None), + ("15", np.int64([0, 0, 0]), 2, 3, None), + ("16", np.int64([4, 0, 6]), 2, 3, None), + ("17", np.int64([4, 0, 6]), 2, 3, 1), + ("18", np.int64([4, 0, 6]), 2, 3, 2), + ) + def testInterleaveDataset(self, input_values, cycle_length, block_length, + num_parallel_calls): + count = 2 + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length, num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() + + def repeat(values, count): + result = [] + for value in values: + result.append([value] * value) + return result * count with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + repeat(input_values, count), cycle_length, block_length): + self.assertEqual(expected_element, sess.run(get_next)) + + for _ in range(2): + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @parameterized.named_parameters( + ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None), + ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1), + ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None), + ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1), + ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2), + ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None), + ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1), + ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2), + ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None), + ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1), + ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3), + ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5), + ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7), + ) + def testInterleaveErrorDataset(self, + input_values, + cycle_length, + block_length, + num_parallel_calls): + dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map( + lambda x: array_ops.check_numerics(x, "message")).interleave( + dataset_ops.Dataset.from_tensors, cycle_length, block_length, + num_parallel_calls) + get_next = dataset.make_one_shot_iterator().get_next() - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. - sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) + with self.test_session() as sess: + for value in input_values: + if np.isnan(value): + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + else: + self.assertEqual(value, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) def testSparse(self): @@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testEmptyInput(self): - iterator = ( - dataset_ops.Dataset.from_tensor_slices([]) - .repeat(None) - .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - if __name__ == "__main__": test.main() |