diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-08-01 14:52:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 15:01:24 -0700 |
commit | 52dc7286bda07a53b4bc6e5ca17ff22fc5d72af5 (patch) | |
tree | e4d230a155a10ef237b8fc05e531f98f39349d56 /tensorflow/contrib/eager | |
parent | a28ad4b26dbb8cb1e9cf2135f72f3f55ffabf037 (diff) |
[Checkpointable] Make Iterator checkpointable.
Use object-based save/restore to make dataset/iterator checkpointable in both graph as well as eager mode.
PiperOrigin-RevId: 206998349
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/datasets.py | 32 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/datasets_test.py | 13 |
2 files changed, 14 insertions, 31 deletions
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index e31dbbe80f..16844e0d68 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the @@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index acc605247f..2917eaac97 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -306,6 +307,18 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(saver.latest_checkpoint(checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): |