diff options
author | 2017-10-23 11:07:10 -0700 | |
---|---|---|
committer | 2017-10-23 11:11:17 -0700 | |
commit | 1038927c096ecc81ca48665871d1be390444b121 (patch) | |
tree | 40c7ff20843bc62f248153b5b85e8116e16c3f4c /tensorflow/python/kernel_tests | |
parent | 57f3e529d935e6b08a6c0a3a418ad367d9314fde (diff) |
Add SerializeIterator op that serializes an IteratorResource into a variant tensor.
Add DeserializeIterator op that builds IteratorResource from a variant tensor.
Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc.
Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators.
Get rid of IteratorBundleReader and IteratorBundleWriter.
PiperOrigin-RevId: 173140858
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 5 |
-rw-r--r-- | tensorflow/python/kernel_tests/iterator_ops_test.py | 29 |
-rw-r--r-- | tensorflow/python/kernel_tests/range_dataset_op_test.py | 67 |
-rw-r--r-- | tensorflow/python/kernel_tests/reader_dataset_ops_test.py | 26 |
4 files changed, 79 insertions, 48 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 0e36c3498a..b02bae95fd 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2886,7 +2886,9 @@ tf_py_test( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:io_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", @@ -2907,7 +2909,9 @@ tf_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", @@ -3022,6 +3026,7 @@ tf_py_test( "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:script_ops", diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py index b5ec9f7db0..2128ef4ae1 100644 --- a/tensorflow/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/kernel_tests/iterator_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -538,9 +539,23 @@ class IteratorTest(test.TestCase): def testIncorrectIteratorRestore(self): - def _iterator_checkpoint_prefix(): + def _path(): return 
os.path.join(self.get_temp_dir(), "iterator") + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_range_dataset_graph(): start = 1 stop = 10 @@ -548,22 +563,18 @@ class IteratorTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = _iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op def _build_reader_dataset_graph(): filenames = ["test"] # Does not exist but we don't care in this test. - path = _iterator_checkpoint_prefix() iterator = readers.FixedLengthRecordDataset( filenames, 1, 0, 0).make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op # Saving iterator for RangeDataset graph. 
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py index 8291967155..0c530522b8 100644 --- a/tensorflow/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py @@ -27,6 +27,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase): def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def testSaveRestore(self): def _build_graph(start, stop): @@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, 
restore_op # Saving and restoring in different sessions. @@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase): def testRestoreWithoutBuildingDatasetGraph(self): - def _build_graph(start, stop, num_epochs, path): + def _build_graph(start, stop, num_epochs): dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase): num_epochs = 5 break_point = 5 break_epoch = 3 - path = self._iterator_checkpoint_prefix() with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs, - path) + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: sess.run(variables.global_variables_initializer()) sess.run(init_op) @@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + restore_op = self._restore_op(iterator._iterator_resource) get_next = iterator.get_next() with self.test_session(graph=g) as sess: sess.run(restore_op) @@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = 
gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -447,10 +450,8 @@ class 
RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py index 38420328ef..c8e7333b4b 100644 --- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py @@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase): def _iterator_checkpoint_path(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_iterator_graph(self, num_epochs): filenames = 
self._createFiles() - path = self._iterator_checkpoint_path() dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op def _restore_iterator(self): @@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) get_next = iterator.get_next() - restore_op = gen_dataset_ops.restore_iterator( - iterator._iterator_resource, self._iterator_checkpoint_path()) + restore_op = self._restore_op(iterator._iterator_resource) return restore_op, get_next def testSaveRestore(self): |