From 52dc7286bda07a53b4bc6e5ca17ff22fc5d72af5 Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Wed, 1 Aug 2018 14:52:59 -0700 Subject: [Checkpointable] Make Iterator checkpointable. Use object-based save/restore to make dataset/iterator checkpointable in both graph as well as eager mode. PiperOrigin-RevId: 206998349 --- tensorflow/contrib/eager/python/datasets.py | 32 +------- tensorflow/contrib/eager/python/datasets_test.py | 13 +++ tensorflow/python/BUILD | 36 ++++++++ tensorflow/python/data/kernel_tests/BUILD | 4 + .../python/data/kernel_tests/iterator_ops_test.py | 96 ++++++++++++++++++++++ tensorflow/python/data/ops/BUILD | 2 + tensorflow/python/data/ops/iterator_ops.py | 38 ++++++++- .../api/golden/tensorflow.data.-iterator.pbtxt | 1 + 8 files changed, 189 insertions(+), 33 deletions(-) (limited to 'tensorflow') diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index e31dbbe80f..16844e0d68 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the @@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index acc605247f..2917eaac97 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -306,6 +307,18 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(saver.latest_checkpoint(checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index d35731d3cd..2ccaae4dcf 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3216,6 +3216,7 @@ py_library( # The following targets have their own build rules (same name as the # file): "training/saveable_object.py", + "training/saver.py", "training/training_util.py", ], ), @@ -3247,6 +3248,7 @@ py_library( ":random_ops", ":resource_variable_ops", ":resources", + "saver", ":saveable_object", ":sdca_ops", ":sparse_ops", @@ -3277,6 +3279,40 @@ py_library( srcs_version = "PY2AND3", ) +py_library( + name = "saver", + srcs = ["training/saver.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":constant_op", + ":control_flow_ops", + ":device", + ":errors", + ":framework", + ":framework_ops", + ":io_ops", + ":io_ops_gen", + ":lib", + ":platform", + ":protos_all_py", + ":pywrap_tensorflow", + ":resource_variable_ops", + ":saveable_object", + ":session", + ":state_ops", + ":string_ops", + ":training_util", + ":util", + ":variables", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "device_util", srcs = ["training/device_util.py"], diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index b66b87ce6c..a6b89ce102 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -329,6 +329,8 @@ cuda_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -350,6 +352,8 @@ cuda_py_test( "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python/compat:compat", + "//tensorflow/python:util", + "//tensorflow/python:variables", ], grpc_enabled = True, ) diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index b434fa7334..dd39262f9b 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os import warnings @@ -46,7 +47,9 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import saver from tensorflow.python.training import server_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -788,5 +791,98 @@ class IteratorTest(test.TestCase): val += 1 +class IteratorCheckpointingTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreOneShotIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map( + math_ops.square).batch(2) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator.get_next()) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + with self.test_session() as sess: + self.assertAllEqual([1, 4], get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + with self.assertRaises(errors.OutOfRangeError): + get_next() + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.from_tensor_slices( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = dataset.make_one_shot_iterator() + get_next_1 = iterator_1.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_1.get_next()) + iterator_2 = dataset.make_one_shot_iterator() + get_next_2 = iterator_2.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_2.get_next()) + dataset_2 = dataset_ops.Dataset.range(10) + iterator_3 = dataset_2.make_one_shot_iterator() + get_next_3 = iterator_3.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_3.get_next()) + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + with self.test_session() as sess: + self.assertAllEqual([1, 4], get_next_1()) + self.assertAllEqual(0, get_next_3()) + self.assertAllEqual(1, get_next_3()) + self.assertAllEqual(2, get_next_3()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual([9, 16], get_next_2()) + self.assertAllEqual(3, get_next_3()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual([9, 16], get_next_1()) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual(3, get_next_3()) + + @test_util.run_in_graph_and_eager_modes + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.range(3) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator.get_next()) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + with self.test_session() as sess: + self.assertAllEqual(0, get_next()) + self.assertAllEqual(1, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual(2, get_next()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual(2, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + checkpoint.restore(save_path).run_restore_ops(sess) + with self.assertRaises(errors.OutOfRangeError): + get_next() + + def testRestoreInReconstructedIteratorInitializable(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.range(10) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + for i in range(5): + with self.test_session() as sess: + checkpoint.restore(saver.latest_checkpoint( + checkpoint_directory)).initialize_or_restore(sess) + for j in range(2): + self.assertEqual(i * 2 + j, sess.run(get_next)) + checkpoint.save(file_prefix=checkpoint_prefix) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index f15eb6310f..4b393abf02 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -54,10 +54,12 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:saver", "//tensorflow/python:tensor_shape", "//tensorflow/python/compat", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", ], ) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 494df178df..7eb64390ef 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -30,6 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.tf_export import tf_export @@ -65,7 +67,7 @@ def _device_stack_is_empty(): @tf_export("data.Iterator") -class Iterator(object): +class Iterator(checkpointable.CheckpointableBase): """Represents the state of iterating through a `Dataset`.""" def __init__(self, iterator_resource, initializer, output_types, @@ -464,6 +466,13 @@ class Iterator(object): """ return self._output_types + def _gather_saveables_for_checkpoint(self): + + def _saveable_factory(name): + return _IteratorSaveable(self._iterator_resource, name) + + return {"ITERATOR": _saveable_factory} + _uid_counter = 0 _uid_lock = threading.Lock() @@ -477,7 +486,7 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class EagerIterator(object): +class EagerIterator(checkpointable.CheckpointableBase): """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" def __init__(self, dataset): @@ -610,3 +619,28 @@ class EagerIterator(object): """ del name return self._next_internal() + + def _gather_saveables_for_checkpoint(self): + + def _saveable_factory(name): + return _IteratorSaveable(self._resource, name) + + return {"ITERATOR": _saveable_factory} + + +# TODO(b/71645805): Expose checkpointable stateful objects from dataset +# attributes(potential). +class _IteratorSaveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(_IteratorSaveable, self).__init__(iterator_resource, specs, name) + + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt index 1f9aeb6ad6..4f0147a523 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.data.Iterator" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "initializer" -- cgit v1.2.3