aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py255
1 files changed, 156 insertions, 99 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 5590a4bf78..8b2f846494 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
@@ -29,28 +30,45 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase):
-
- def testSlideDataset(self):
- """Test an dataset that maps a TF function across its input elements."""
+class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ (20, 14, 7, 1),
+ (20, 17, 9, 1),
+ (20, 14, 14, 1),
+ (20, 10, 14, 1),
+ (20, 14, 19, 1),
+ (20, 4, 1, 2),
+ (20, 2, 1, 6),
+ (20, 4, 7, 2),
+ (20, 2, 7, 6),
+ (1, 10, 4, 1),
+ (0, 10, 4, 1),
+ )
+ def testSlideDataset(self, count, window_size, window_shift, window_stride):
+ """Tests a dataset that slides a window its input elements."""
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
- count = array_ops.placeholder(dtypes.int64, shape=[])
- window_size = array_ops.placeholder(dtypes.int64, shape=[])
- stride = array_ops.placeholder(dtypes.int64, shape=[])
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_stride_t = array_ops.placeholder(dtypes.int64, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
- # RepeatDataset(count) -> _SlideDataset(window_size, stride).
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .map(_map_fn)
- .repeat(count)
- .apply(sliding.sliding_window_batch(window_size, stride))
- .make_initializable_iterator())
+ # RepeatDataset(count) ->
+ # _SlideDataset(window_size, window_shift, window_stride).
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(count).apply(
+ sliding.sliding_window_batch(
+ window_size=window_size_t,
+ window_shift=window_shift_t,
+ window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -58,90 +76,126 @@ class SlideDatasetTest(test.TestCase):
[t.shape.as_list() for t in get_next])
with self.test_session() as sess:
- # stride < window_size.
- # Slide over a finite input, where the window_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
- # Same formula with convolution layer.
- num_batches = (20 * 7 - 14) // 7 + 1
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i*7 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- # Slide over a finite input, where the window_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})
- num_batches = (20 * 7 - 17) // 9 + 1
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ window_size_t: window_size,
+ window_shift_t: window_shift,
+ window_stride_t: window_stride
+ })
+ num_batches = (count * 7 - (
+ (window_size - 1) * window_stride + 1)) // window_shift + 1
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
- for j in range(17):
- self.assertAllEqual(component[(i*9 + j) % 7]**2,
- result_component[j])
+ for j in range(window_size):
+ self.assertAllEqual(
+ component[(i * window_shift + j * window_stride) % 7]**2,
+ result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # stride == window_size.
- sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14})
- num_batches = 20 * 7 // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ @parameterized.parameters(
+ (20, 14, 7, 1),
+ (20, 17, 9, 1),
+ (20, 14, 14, 1),
+ (20, 10, 14, 1),
+ (20, 14, 19, 1),
+ (20, 4, 1, 2),
+ (20, 2, 1, 6),
+ (20, 4, 7, 2),
+ (20, 2, 7, 6),
+ (1, 10, 4, 1),
+ (0, 10, 4, 1),
+ )
+ def testSlideDatasetDeprecated(self, count, window_size, stride,
+ window_stride):
+ """Tests a dataset that slides a window its input elements."""
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
- # stride > window_size.
- sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14})
- num_batches = 20 * 7 // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(10):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- # Drop the last batch which is smaller than window_size.
- sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19})
- num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i*19 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_stride_t = array_ops.placeholder(dtypes.int64, shape=[])
- # Slide over a finite input, which is less than window_size,
- # should fail straight away.
- sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
+ # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride).
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(count).apply(
+ sliding.sliding_window_batch(
+ window_size=window_size_t,
+ stride=stride_t,
+ window_stride=window_stride_t)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
- # Slide over an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4})
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.test_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ window_size_t: window_size,
+ stride_t: stride,
+ window_stride_t: window_stride
+ })
+ num_batches = (count * 7 - (
+ (window_size - 1) * window_stride + 1)) // stride + 1
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(window_size):
+ self.assertAllEqual(
+ component[(i * stride + j * window_stride) % 7]**2,
+ result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Empty window_size should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0})
+ @parameterized.parameters(
+ (14, 0, 3, 1),
+ (14, 3, 0, 1),
+ (14, 3, 3, 0),
+ )
+ def testSlideDatasetInvalid(self, count, window_size, window_shift,
+ window_stride):
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ window_stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ iterator = (
+ dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply(
+ sliding.sliding_window_batch(
+ window_size=window_size_t,
+ window_shift=window_shift_t,
+ window_stride=window_stride_t)).make_initializable_iterator())
+ init_op = iterator.initializer
- # Invalid stride should be an initialization time error.
+ with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ window_size_t: window_size,
+ window_shift_t: window_shift,
+ window_stride_t: window_stride
+ })
+
+ def testSlideDatasetValueError(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.Dataset.range(10).map(lambda x: x).apply(
+ sliding.sliding_window_batch(
+ window_size=1, stride=1, window_shift=1, window_stride=1))
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
@@ -155,7 +209,8 @@ class SlideDatasetTest(test.TestCase):
indices=[[0]], values=(i * [1]), dense_shape=[1])
iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
- sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
+ sliding.sliding_window_batch(
+ window_size=5, window_shift=3)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -183,7 +238,8 @@ class SlideDatasetTest(test.TestCase):
dense_shape=[i])
iterator = dataset_ops.Dataset.range(10).map(_sparse).apply(
- sliding.sliding_window_batch(5, 3)).make_initializable_iterator()
+ sliding.sliding_window_batch(
+ window_size=5, window_shift=3)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -213,11 +269,11 @@ class SlideDatasetTest(test.TestCase):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])
- iterator = (dataset_ops.Dataset.range(10)
- .map(_sparse)
- .apply(sliding.sliding_window_batch(4, 2))
- .apply(sliding.sliding_window_batch(3, 1))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.range(10).map(_sparse).apply(
+ sliding.sliding_window_batch(window_size=4, window_shift=2)).apply(
+ sliding.sliding_window_batch(window_size=3, window_shift=1))
+ .make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -226,9 +282,9 @@ class SlideDatasetTest(test.TestCase):
# Slide: 1st batch.
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
- [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
- [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
@@ -236,9 +292,9 @@ class SlideDatasetTest(test.TestCase):
# Slide: 2nd batch.
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
- [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0],
- [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]],
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
@@ -253,10 +309,11 @@ class SlideDatasetTest(test.TestCase):
yield [4.0, 5.0, 6.0]
yield [7.0, 8.0, 9.0, 10.0]
- iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
- output_shapes=[None])
- .apply(sliding.sliding_window_batch(3, 1))
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).apply(
+ sliding.sliding_window_batch(window_size=3, window_shift=1))
+ .make_initializable_iterator())
next_element = iterator.get_next()
with self.test_session() as sess: