aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/reader_dataset_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reader_dataset_ops_test.py298
1 files changed, 0 insertions, 298 deletions
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index c8e7333b4b..70b6ce442e 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -26,13 +26,8 @@ from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -272,299 +267,6 @@ class FixedLengthRecordReaderTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
- def _iterator_checkpoint_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_path(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def _build_iterator_graph(self, num_epochs):
- filenames = self._createFiles()
- dataset = (readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
- .repeat(num_epochs))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next_op = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next_op, save_op, restore_op
-
- def _restore_iterator(self):
- output_types = dtypes.string
- output_shapes = tensor_shape.scalar()
- iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
- get_next = iterator.get_next()
- restore_op = self._restore_op(iterator._iterator_resource)
- return restore_op, get_next
-
- def testSaveRestore(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testInitThenRestore(self):
- # Note: Calling init_op before restore_op is redundant. This test just makes
- # sure we do not fail if restore is called on an already initialized
- # iterator resource.
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreInModifiedGraph(self):
- num_epochs = 10
- num_epochs_1 = 20
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs_1)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreWithoutBuildingDatasetGraph(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- restore_op, get_next_op = self._restore_iterator()
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreUnusedIterator(self):
- num_epochs = 10
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- # Save unused iterator.
- sess.run(save_op)
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for _ in range(num_epochs * self._num_files * self._num_records):
- sess.run(get_next_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreExhaustedIterator(self):
- num_epochs = 10
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for _ in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
class TFRecordDatasetTest(test.TestCase):