aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-08-01 14:52:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 15:01:24 -0700
commit52dc7286bda07a53b4bc6e5ca17ff22fc5d72af5 (patch)
treee4d230a155a10ef237b8fc05e531f98f39349d56 /tensorflow/contrib/eager
parenta28ad4b26dbb8cb1e9cf2135f72f3f55ffabf037 (diff)
[Checkpointable] Make Iterator checkpointable.
Use object-based save/restore to make dataset/iterator checkpointable in both graph as well as eager mode. PiperOrigin-RevId: 206998349
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/datasets.py32
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py13
2 files changed, 14 insertions, 31 deletions
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index e31dbbe80f..16844e0d68 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.saver import BaseSaverBuilder
-class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
+class Iterator(iterator_ops.EagerIterator):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
NOTE: Unlike the iterator created by the
@@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
return super(Iterator, self)._next_internal()
-
- # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
- # attributes(potential).
-
- class _Saveable(BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
-
- def __init__(self, iterator_resource, name):
- serialized_iterator = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- specs = [
- BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
- ]
- # pylint: disable=protected-access
- super(Iterator._Saveable, self).__init__(iterator_resource, specs, name)
-
- def restore(self, restored_tensors, restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op,
- restored_tensors[0])
-
- def _gather_saveables_for_checkpoint(self):
-
- def _saveable_factory(name):
- return self._Saveable(self._resource, name)
-
- return {"ITERATOR": _saveable_factory}
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index acc605247f..2917eaac97 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -306,6 +307,18 @@ class IteratorTest(test.TestCase):
checkpoint.restore(save_path)
self.assertEqual(2, iterator.get_next().numpy())
+ def testRestoreInReconstructedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ dataset = Dataset.range(10)
+ for i in range(5):
+ iterator = datasets.Iterator(dataset)
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
+ for j in range(2):
+ self.assertEqual(i * 2 + j, iterator.get_next().numpy())
+ checkpoint.save(file_prefix=checkpoint_prefix)
+
class DatasetConstructorBenchmark(test.Benchmark):