aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-08-01 14:52:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 15:01:24 -0700
commit52dc7286bda07a53b4bc6e5ca17ff22fc5d72af5 (patch)
treee4d230a155a10ef237b8fc05e531f98f39349d56 /tensorflow/python
parenta28ad4b26dbb8cb1e9cf2135f72f3f55ffabf037 (diff)
[Checkpointable] Make Iterator checkpointable.
Use object-based save/restore to make dataset/iterator checkpointable in both graph as well as eager mode. PiperOrigin-RevId: 206998349
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD36
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD4
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_test.py96
-rw-r--r--tensorflow/python/data/ops/BUILD2
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py38
5 files changed, 174 insertions, 2 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d35731d3cd..2ccaae4dcf 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3216,6 +3216,7 @@ py_library(
# The following targets have their own build rules (same name as the
# file):
"training/saveable_object.py",
+ "training/saver.py",
"training/training_util.py",
],
),
@@ -3247,6 +3248,7 @@ py_library(
":random_ops",
":resource_variable_ops",
":resources",
+ "saver",
":saveable_object",
":sdca_ops",
":sparse_ops",
@@ -3278,6 +3280,40 @@ py_library(
)
py_library(
+ name = "saver",
+ srcs = ["training/saver.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":constant_op",
+ ":control_flow_ops",
+ ":device",
+ ":errors",
+ ":framework",
+ ":framework_ops",
+ ":io_ops",
+ ":io_ops_gen",
+ ":lib",
+ ":platform",
+ ":protos_all_py",
+ ":pywrap_tensorflow",
+ ":resource_variable_ops",
+ ":saveable_object",
+ ":session",
+ ":state_ops",
+ ":string_ops",
+ ":training_util",
+ ":util",
+ ":variables",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:base",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "device_util",
srcs = ["training/device_util.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index b66b87ce6c..a6b89ce102 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -329,6 +329,8 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -350,6 +352,8 @@ cuda_py_test(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
],
grpc_enabled = True,
)
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index b434fa7334..dd39262f9b 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
import warnings
@@ -46,7 +47,9 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -788,5 +791,98 @@ class IteratorTest(test.TestCase):
val += 1
+class IteratorCheckpointingTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreOneShotIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
+ math_ops.square).batch(2)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ with self.test_session() as sess:
+ self.assertAllEqual([1, 4], get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreMultipleIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ dataset = dataset.map(math_ops.square).batch(2)
+ iterator_1 = dataset.make_one_shot_iterator()
+ get_next_1 = iterator_1.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_1.get_next())
+ iterator_2 = dataset.make_one_shot_iterator()
+ get_next_2 = iterator_2.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_2.get_next())
+ dataset_2 = dataset_ops.Dataset.range(10)
+ iterator_3 = dataset_2.make_one_shot_iterator()
+ get_next_3 = iterator_3.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_3.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(
+ iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+ with self.test_session() as sess:
+ self.assertAllEqual([1, 4], get_next_1())
+ self.assertAllEqual(0, get_next_3())
+ self.assertAllEqual(1, get_next_3())
+ self.assertAllEqual(2, get_next_3())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual([9, 16], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual([9, 16], get_next_1())
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testRestoreExhaustedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(3)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ with self.test_session() as sess:
+ self.assertAllEqual(0, get_next())
+ self.assertAllEqual(1, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual(2, get_next())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual(2, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ def testRestoreInReconstructedIteratorInitializable(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(10)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ for i in range(5):
+ with self.test_session() as sess:
+ checkpoint.restore(saver.latest_checkpoint(
+ checkpoint_directory)).initialize_or_restore(sess)
+ for j in range(2):
+ self.assertEqual(i * 2 + j, sess.run(get_next))
+ checkpoint.save(file_prefix=checkpoint_prefix)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index f15eb6310f..4b393abf02 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -54,10 +54,12 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:saver",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:base",
],
)
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 494df178df..7eb64390ef 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -30,6 +30,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.tf_export import tf_export
@@ -65,7 +67,7 @@ def _device_stack_is_empty():
@tf_export("data.Iterator")
-class Iterator(object):
+class Iterator(checkpointable.CheckpointableBase):
"""Represents the state of iterating through a `Dataset`."""
def __init__(self, iterator_resource, initializer, output_types,
@@ -464,6 +466,13 @@ class Iterator(object):
"""
return self._output_types
+ def _gather_saveables_for_checkpoint(self):
+
+ def _saveable_factory(name):
+ return _IteratorSaveable(self._iterator_resource, name)
+
+ return {"ITERATOR": _saveable_factory}
+
_uid_counter = 0
_uid_lock = threading.Lock()
@@ -477,7 +486,7 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
-class EagerIterator(object):
+class EagerIterator(checkpointable.CheckpointableBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
def __init__(self, dataset):
@@ -610,3 +619,28 @@ class EagerIterator(object):
"""
del name
return self._next_internal()
+
+ def _gather_saveables_for_checkpoint(self):
+
+ def _saveable_factory(name):
+ return _IteratorSaveable(self._resource, name)
+
+ return {"ITERATOR": _saveable_factory}
+
+
+# TODO(b/71645805): Expose checkpointable stateful objects from dataset
+# attributes(potential).
+class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource, name):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
+ ]
+ # pylint: disable=protected-access
+ super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])