aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-06 13:08:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 13:15:28 -0700
commita8a22af204ef4ddb4ada55c17863dbd286b90b30 (patch)
tree103632aa68cc87fa545a4246f4f228b754036bd5 /tensorflow/contrib/data
parent612166a4f4c79efbe9e34e75652e10300150ec7a (diff)
[tf.data] Naming parameterized tests to facilitate invoking them individually and using consistent style for existing test names.
PiperOrigin-RevId: 211855926
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py22
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py56
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py13
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py144
5 files changed, 149 insertions, 131 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 9d8e955245..67242fecfe 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
@parameterized.named_parameters(
- ("default", None, None),
- ("sequential_calls", 1, None),
- ("parallel_calls", 2, None),
- ("parallel_batches", None, 10),
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
@@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("even", False),
- ("uneven", True),
+ ("Even", False),
+ ("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
@@ -663,7 +663,14 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
for _ in range(3):
sess.run(get_next)
- @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
@@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (False, dtypes.bool),
- (-42, dtypes.int8),
- (-42, dtypes.int16),
- (-42, dtypes.int32),
- (-42, dtypes.int64),
- (42, dtypes.uint8),
- (42, dtypes.uint16),
- (42.0, dtypes.float16),
- (42.0, dtypes.float32),
- (42.0, dtypes.float64),
- (b"hello", dtypes.string),
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 586b4bee5f..6a7ef877f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for i, fun1 in enumerate(functions):
for j, fun2 in enumerate(functions):
tests.append((
- "test_{}_{}".format(i, j),
+ "Test{}{}".format(i, j),
[fun1, fun2],
))
for k, fun3 in enumerate(functions):
tests.append((
- "test_{}_{}_{}".format(i, j, k),
+ "Test{}{}{}".format(i, j, k),
[fun1, fun2, fun3],
))
swap = lambda x, n: (n, x)
tests.append((
- "swap1",
+ "Swap1",
[lambda x: (x, 42), swap],
))
tests.append((
- "swap2",
+ "Swap2",
[lambda x: (x, 42), swap, swap],
))
return tuple(tests)
@@ -109,13 +109,13 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for x, fun in enumerate(functions):
for y, predicate in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
lambda x, y: constant_op.constant(True)))
tests.append(
- ("multiTwo", lambda x: (x, 2),
+ ("Multi2", lambda x: (x, 2),
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
return tuple(tests)
@@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
identity = lambda x: x
for x, predicate_1 in enumerate(filters):
for y, predicate_2 in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), identity,
+ tests.append(("Mixed{}{}".format(x, y), identity,
[predicate_1, predicate_2]))
for z, predicate_3 in enumerate(filters):
- tests.append(("mixed_{}_{}_{}".format(x, y, z), identity,
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
[predicate_1, predicate_2, predicate_3]))
take_all_multiple = lambda x, y: constant_op.constant(True)
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
[take_all_multiple, take_all_multiple]))
- tests.append(("multiTwo", lambda x: (x, 2), [
+ tests.append(("Multi2", lambda x: (x, 2), [
take_all_multiple,
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
]))
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846494..6b3e8e9f6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@ from tensorflow.python.platform import test
class SlideDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDataset(self, count, window_size, window_shift, window_stride):
"""Tests a dataset that slides a window its input elements."""
@@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDatasetDeprecated(self, count, window_size, stride,
window_stride):
@@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (14, 0, 3, 1),
- (14, 3, 0, 1),
- (14, 3, 3, 0),
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
)
def testSlideDatasetInvalid(self, count, window_size, window_shift,
window_stride):
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2bce2..4b08ec759d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@ from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
- (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6754..ff4d9b3260 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual(xs, ys)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
"""Tests windowing by chaining it with flat map.
@@ -97,15 +97,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchDense(self, structure, shape, dtype):
"""Tests batching of dense tensor windows.
@@ -135,10 +135,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
"""Tests batching of dynamically shaped dense tensor windows.
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
"""Tests batching of sparse tensor windows.
@@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchSparseDynamicShape(self, shape):
"""Tests batching of dynamically shaped sparse tensor windows.
@@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
]))
- @parameterized.parameters(
- (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
)
def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
padded_shape):
@@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1], [2], [3]]), [-1]),
- (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1], [2], [3]]), [-1]),
+ ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
"""Tests padded batching of dynamically shaped dense tensor windows.
@@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1]]), np.int32([0])),
- (np.int32([[10], [20]]), np.int32([15])),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1]]), np.int32([0])),
+ ("2", np.int32([[10], [20]]), np.int32([15])),
)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of dense tensor windows.
@@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
padded_shape):
@@ -463,10 +465,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1], [2], [3]]), [-1]),
- (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1], [2], [3]]), [-1]),
+ ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
padded_shape):
@@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1]]), [0]),
- (np.int64([[10], [20]]), [15]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1]]), [0]),
+ ("2", np.int64([[10], [20]]), [15]),
)
def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of sparse tensor windows.