diff options
author | 2018-09-28 15:22:06 -0700 | |
---|---|---|
committer | 2018-09-28 15:25:06 -0700 | |
commit | 1c4a48ddd49f78fbd8ea3defd3a8755c91284166 (patch) | |
tree | 32c6e79a58dd3b6dc62c4eb225651c6fb003d16a /tensorflow/python/data | |
parent | 2f559f2d5f75cf80183ae0d855110809404019f7 (diff) |
[tf.data] Merged contrib.data's DatasetTestBase with the DatasetTestBase in core (and added that as a base class for all the contrib tests). Also changed the assertDatasetsEqual functions so they are both graph and eager compatible (took the code from CSVDatasetTest) :)
PiperOrigin-RevId: 215004892
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/test_base.py | 80 |
2 files changed, 83 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 5f9818566f..cadfe7f9e0 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -471,6 +471,9 @@ py_library( srcs = ["test_base.py"], deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/util:nest", ], ) diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index b4f64115b7..b730e10949 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -17,6 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + +from tensorflow.python.data.util import nest +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import test @@ -24,6 +30,80 @@ class DatasetTestBase(test.TestCase): """Base class for dataset tests.""" def assertSparseValuesEqual(self, a, b): + """Asserts that two SparseTensors/SparseTensorValues are equal.""" self.assertAllEqual(a.indices, b.indices) self.assertAllEqual(a.values, b.values) self.assertAllEqual(a.dense_shape, b.dense_shape) + + def getNext(self, dataset): + """Returns a callable that returns the next element of the dataset. + + Example use: + ```python + # In both graph and eager modes + dataset = ... + nxt = self.getNext(dataset) + result = self.evaluate(nxt()) + ``` + + Args: + dataset: A dataset whose next element is returned + + Returns: + A callable that returns the next element of `dataset` + """ + it = dataset.make_one_shot_iterator() + if context.executing_eagerly(): + return it.get_next + else: + nxt = it.get_next() + return lambda: nxt + + def assertDatasetsEqual(self, dataset1, dataset2): + """Checks that datasets are equal. Supports both graph and eager mode.""" + self.assertEqual(dataset1.output_types, dataset2.output_types) + self.assertEqual(dataset1.output_classes, dataset2.output_classes) + + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + while True: + try: + op1 = self.evaluate(next1()) + except errors.OutOfRangeError: + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next2()) + break + op2 = self.evaluate(next2()) + + op1 = nest.flatten(op1) + op2 = nest.flatten(op2) + assert len(op1) == len(op2) + for i in range(len(op1)): + if isinstance( + op1[i], + (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + self.assertSparseValuesEqual(op1[i], op2[i]) + else: + self.assertAllEqual(op1[i], op2[i]) + + def assertDatasetsRaiseSameError(self, + dataset1, + dataset2, + exception_class, + replacements=None): + """Checks that datasets raise the same error on the first get_next call.""" + next1 = self.getNext(dataset1) + next2 = self.getNext(dataset2) + try: + self.evaluate(next1()) + raise ValueError( + 'Expected dataset to raise an error of type %s, but it did not.' % + repr(exception_class)) + except exception_class as e: + expected_message = e.message + for old, new, count in replacements: + expected_message = expected_message.replace(old, new, count) + # Check that the first segment of the error messages are the same. + with self.assertRaisesRegexp(exception_class, + re.escape(expected_message)): + self.evaluate(next2()) |