aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py77
1 files changed, 26 insertions, 51 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 07fecf04fa..df9147af6c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -32,7 +32,7 @@ from tensorflow.python.util import nest
class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing serializable datasets."""
+ """Base class for testing finite serializable datasets."""
def tearDown(self):
self._delete_ckpt()
@@ -58,19 +58,17 @@ class DatasetSerializationTestBase(test.TestCase):
if ds_fn2:
self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs)
- def verify_unused_iterator(self, ds_fn, num_outputs, verify_exhausted=True):
+ def verify_unused_iterator(self, ds_fn, num_outputs):
"""Verifies that saving and restoring an unused iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
- self.verify_run_with_breaks(
- ds_fn, [0], num_outputs, verify_exhausted=verify_exhausted)
+ self.verify_run_with_breaks(ds_fn, [0], num_outputs)
def verify_fully_used_iterator(self, ds_fn, num_outputs):
"""Verifies that saving and restoring a fully used iterator works.
@@ -106,16 +104,12 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True)
self.assertEqual(len(actual), 0)
- def verify_init_before_restore(self,
- ds_fn,
- num_outputs,
- verify_exhausted=True):
+ def verify_init_before_restore(self, ds_fn, num_outputs):
"""Verifies that retoring into an already initilized iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -124,14 +118,9 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn,
self.gen_break_points(num_outputs),
num_outputs,
- init_before_restore=True,
- verify_exhausted=verify_exhausted)
+ init_before_restore=True)
- def verify_multiple_breaks(self,
- ds_fn,
- num_outputs,
- num_breaks=10,
- verify_exhausted=True):
+ def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10):
"""Attempts to save/restore at multiple break points.
Args:
@@ -139,22 +128,16 @@ class DatasetSerializationTestBase(test.TestCase):
num_outputs: See `run_core_tests`.
num_breaks: The number of break points. These are uniformly spread in
[0, num_outputs] both inclusive.
- verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs),
- num_outputs,
- verify_exhausted=verify_exhausted)
+ self.verify_run_with_breaks(ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs)
- def verify_reset_restored_iterator(self,
- ds_fn,
- num_outputs,
- break_point=None,
- verify_exhausted=True):
+ def verify_reset_restored_iterator(self, ds_fn, num_outputs,
+ break_point=None):
"""Attempts to re-initialize a restored iterator.
This is useful when restoring a training checkpoint during validation.
@@ -163,7 +146,6 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
- verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -171,8 +153,7 @@ class DatasetSerializationTestBase(test.TestCase):
break_point = num_outputs // 2 if not break_point else break_point
# Collect ground truth containing all outputs.
- expected = self.gen_outputs(
- ds_fn, [], num_outputs, verify_exhausted=verify_exhausted)
+ expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True)
# Skip some items and save checkpoint.
self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False)
@@ -187,17 +168,15 @@ class DatasetSerializationTestBase(test.TestCase):
sess.run(init_op)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
self.match(expected, actual)
def verify_restore_in_modified_graph(self,
ds_fn1,
ds_fn2,
num_outputs,
- break_point=None,
- verify_exhausted=True):
+ break_point=None):
"""Attempts to restore an iterator in a modified graph.
Builds an input pipeline using ds_fn1, runs it for `break_point` steps
@@ -209,7 +188,6 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn2: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
- verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
@@ -218,15 +196,15 @@ class DatasetSerializationTestBase(test.TestCase):
# Skip `break_point` items and store the remaining produced from ds_fn1
# in `expected`.
- self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
+ self.gen_outputs(ds_fn1, [], break_point)
expected = self.gen_outputs(
ds_fn1, [],
num_outputs - break_point,
ckpt_saved=True,
- verify_exhausted=verify_exhausted)
+ verify_exhausted=True)
# Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
+ self.gen_outputs(ds_fn1, [], break_point)
actual = []
# Build graph for ds_fn2 but load checkpoint for ds_fn1.
@@ -236,9 +214,8 @@ class DatasetSerializationTestBase(test.TestCase):
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
self.match(expected, actual)
@@ -246,7 +223,6 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn,
break_points,
num_outputs,
- verify_exhausted=True,
init_before_restore=False):
"""Verifies that ds_fn() produces the same outputs with and without breaks.
@@ -261,7 +237,6 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn: See `gen_outputs`.
break_points: See `gen_outputs`.
num_outputs: See `gen_outputs`.
- verify_exhausted: See `gen_outputs`.
init_before_restore: See `gen_outputs`.
Raises:
@@ -270,13 +245,13 @@ class DatasetSerializationTestBase(test.TestCase):
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
- verify_exhausted=verify_exhausted,
+ verify_exhausted=True,
init_before_restore=init_before_restore)
actual = self.gen_outputs(
ds_fn,
break_points,
num_outputs,
- verify_exhausted=verify_exhausted,
+ verify_exhausted=True,
init_before_restore=init_before_restore)
self.match(expected, actual)
@@ -286,7 +261,7 @@ class DatasetSerializationTestBase(test.TestCase):
num_outputs,
ckpt_saved=False,
init_before_restore=False,
- verify_exhausted=True):
+ verify_exhausted=False):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
@@ -310,7 +285,7 @@ class DatasetSerializationTestBase(test.TestCase):
after producing `num_outputs` elements.
Returns:
- A list of `num_outputs` items.
+ A list if `num_outputs` items.
"""
outputs = []
@@ -337,11 +312,11 @@ class DatasetSerializationTestBase(test.TestCase):
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
+ self._save(sess, saver)
+ ckpt_saved = True
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
- self._save(sess, saver)
- ckpt_saved = True
return outputs