aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py82
1 files changed, 57 insertions, 25 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index 87b7c6ddb7..e6883d53e0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase):
+class GetSingleElementTest(test.TestCase, parameterized.TestCase):
- def testGetSingleElement(self):
- skip_value = array_ops.placeholder(dtypes.int64, shape=[])
- take_value = array_ops.placeholder_with_default(
- constant_op.constant(1, dtype=dtypes.int64), shape=[])
+ @parameterized.named_parameters(
+ ("Zero", 0, 1),
+ ("Five", 5, 1),
+ ("Ten", 10, 1),
+ ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+ ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+ "Dataset had more than one element."),
+ )
+ def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+ skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+ take_t = array_ops.placeholder(dtypes.int64, shape=[])
def make_sparse(x):
x_1d = array_ops.reshape(x, [1])
x_2d = array_ops.reshape(x, [1, 1])
return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
- dataset = (dataset_ops.Dataset.range(100)
- .skip(skip_value)
- .map(lambda x: (x * x, make_sparse(x)))
- .take(take_value))
-
+ dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+ lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
with self.test_session() as sess:
- for x in [0, 5, 10]:
- dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x})
- self.assertEqual(x * x, dense_val)
- self.assertAllEqual([[x]], sparse_val.indices)
- self.assertAllEqual([x], sparse_val.values)
- self.assertAllEqual([x], sparse_val.dense_shape)
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Dataset was empty."):
- sess.run(element, feed_dict={skip_value: 100})
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Dataset had more than one element."):
- sess.run(element, feed_dict={skip_value: 0, take_value: 2})
+ if error is None:
+ dense_val, sparse_val = sess.run(
+ element, feed_dict={
+ skip_t: skip,
+ take_t: take
+ })
+ self.assertEqual(skip * skip, dense_val)
+ self.assertAllEqual([[skip]], sparse_val.indices)
+ self.assertAllEqual([skip], sparse_val.values)
+ self.assertAllEqual([skip], sparse_val.dense_shape)
+ else:
+ with self.assertRaisesRegexp(error, error_msg):
+ sess.run(element, feed_dict={skip_t: skip, take_t: take})
+
+ @parameterized.named_parameters(
+ ("SumZero", 0),
+ ("SumOne", 1),
+ ("SumFive", 5),
+ ("SumTen", 10),
+ )
+ def testReduceDataset(self, stop):
+ def init_fn(_):
+ return np.int64(0)
+
+ def reduce_fn(state, value):
+ return state + value
+
+ def finalize_fn(state):
+ return state
+
+ sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+ stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset_ops.Dataset.range(stop_t)
+ element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+ with self.test_session() as sess:
+ value = sess.run(element, feed_dict={stop_t: stop})
+ self.assertEqual(stop * (stop - 1) / 2, value)
if __name__ == "__main__":