diff options
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py | 77 |
1 files changed, 26 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 07fecf04fa..df9147af6c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -32,7 +32,7 @@ from tensorflow.python.util import nest class DatasetSerializationTestBase(test.TestCase): - """Base class for testing serializable datasets.""" + """Base class for testing finite serializable datasets.""" def tearDown(self): self._delete_ckpt() @@ -58,19 +58,17 @@ class DatasetSerializationTestBase(test.TestCase): if ds_fn2: self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs) - def verify_unused_iterator(self, ds_fn, num_outputs, verify_exhausted=True): + def verify_unused_iterator(self, ds_fn, num_outputs): """Verifies that saving and restoring an unused iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. - verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks( - ds_fn, [0], num_outputs, verify_exhausted=verify_exhausted) + self.verify_run_with_breaks(ds_fn, [0], num_outputs) def verify_fully_used_iterator(self, ds_fn, num_outputs): """Verifies that saving and restoring a fully used iterator works. @@ -106,16 +104,12 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True) self.assertEqual(len(actual), 0) - def verify_init_before_restore(self, - ds_fn, - num_outputs, - verify_exhausted=True): + def verify_init_before_restore(self, ds_fn, num_outputs): """Verifies that retoring into an already initilized iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. - verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -124,14 +118,9 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, self.gen_break_points(num_outputs), num_outputs, - init_before_restore=True, - verify_exhausted=verify_exhausted) + init_before_restore=True) - def verify_multiple_breaks(self, - ds_fn, - num_outputs, - num_breaks=10, - verify_exhausted=True): + def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10): """Attempts to save/restore at multiple break points. Args: @@ -139,22 +128,16 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs: See `run_core_tests`. num_breaks: The number of break points. These are uniformly spread in [0, num_outputs] both inclusive. - verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks( - ds_fn, - self.gen_break_points(num_outputs), - num_outputs, - verify_exhausted=verify_exhausted) + self.verify_run_with_breaks(ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs) - def verify_reset_restored_iterator(self, - ds_fn, - num_outputs, - break_point=None, - verify_exhausted=True): + def verify_reset_restored_iterator(self, ds_fn, num_outputs, + break_point=None): """Attempts to re-initialize a restored iterator. This is useful when restoring a training checkpoint during validation. @@ -163,7 +146,6 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. - verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -171,8 +153,7 @@ class DatasetSerializationTestBase(test.TestCase): break_point = num_outputs // 2 if not break_point else break_point # Collect ground truth containing all outputs. - expected = self.gen_outputs( - ds_fn, [], num_outputs, verify_exhausted=verify_exhausted) + expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) # Skip some items and save checkpoint. self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False) @@ -187,17 +168,15 @@ class DatasetSerializationTestBase(test.TestCase): sess.run(init_op) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) - if verify_exhausted: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) def verify_restore_in_modified_graph(self, ds_fn1, ds_fn2, num_outputs, - break_point=None, - verify_exhausted=True): + break_point=None): """Attempts to restore an iterator in a modified graph. Builds an input pipeline using ds_fn1, runs it for `break_point` steps @@ -209,7 +188,6 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn2: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. - verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -218,15 +196,15 @@ class DatasetSerializationTestBase(test.TestCase): # Skip `break_point` items and store the remaining produced from ds_fn1 # in `expected`. - self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False) + self.gen_outputs(ds_fn1, [], break_point) expected = self.gen_outputs( ds_fn1, [], num_outputs - break_point, ckpt_saved=True, - verify_exhausted=verify_exhausted) + verify_exhausted=True) # Generate `break_point` items from ds_fn1 and save checkpoint. - self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False) + self.gen_outputs(ds_fn1, [], break_point) actual = [] # Build graph for ds_fn2 but load checkpoint for ds_fn1. @@ -236,9 +214,8 @@ class DatasetSerializationTestBase(test.TestCase): self._restore(saver, sess) for _ in range(num_outputs - break_point): actual.append(sess.run(get_next_op)) - if verify_exhausted: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) @@ -246,7 +223,6 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, break_points, num_outputs, - verify_exhausted=True, init_before_restore=False): """Verifies that ds_fn() produces the same outputs with and without breaks. @@ -261,7 +237,6 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn: See `gen_outputs`. break_points: See `gen_outputs`. num_outputs: See `gen_outputs`. - verify_exhausted: See `gen_outputs`. init_before_restore: See `gen_outputs`. Raises: @@ -270,13 +245,13 @@ class DatasetSerializationTestBase(test.TestCase): expected = self.gen_outputs( ds_fn, [], num_outputs, - verify_exhausted=verify_exhausted, + verify_exhausted=True, init_before_restore=init_before_restore) actual = self.gen_outputs( ds_fn, break_points, num_outputs, - verify_exhausted=verify_exhausted, + verify_exhausted=True, init_before_restore=init_before_restore) self.match(expected, actual) @@ -286,7 +261,7 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, ckpt_saved=False, init_before_restore=False, - verify_exhausted=True): + verify_exhausted=False): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -310,7 +285,7 @@ class DatasetSerializationTestBase(test.TestCase): after producing `num_outputs` elements. Returns: - A list of `num_outputs` items. + A list if `num_outputs` items. """ outputs = [] @@ -337,11 +312,11 @@ class DatasetSerializationTestBase(test.TestCase): num_iters = end - start for _ in range(num_iters): outputs.append(sess.run(get_next_op)) + self._save(sess, saver) + ckpt_saved = True if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True return outputs |