-rw-r--r--  tensorflow/contrib/data/__init__.py                          1
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD                    21
-rw-r--r--  tensorflow/contrib/data/python/ops/iterator_ops.py         169
-rw-r--r--  tensorflow/contrib/data/python/ops/iterator_ops_test.py    123
-rw-r--r--  tensorflow/python/data/ops/iterator_ops.py                   7
5 files changed, 314 insertions, 7 deletions
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 077cbba9d2..4f2c72b660 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -72,6 +72,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
+from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 5b04c5316c..144460fde0 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -45,6 +45,27 @@ py_library(
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
+        "//tensorflow/python/data/ops:iterator_ops",
+    ],
+)
+
+py_test(
+    name = "iterator_ops_test",
+    size = "small",
+    srcs = ["iterator_ops_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":iterator_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:model_fn",
    ],
)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index d736029fb0..f1d0e5cddc 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,10 +16,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook


def make_saveable_from_iterator(iterator):
@@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator):
  return _Saveable(iterator._iterator_resource)  # pylint: disable=protected-access


-class _Saveable(saver.BaseSaverBuilder.SaveableObject):
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(self, iterator_resource):
    serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
    specs = [
-        saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
-                                        iterator_resource.name + "-state")
+        saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+                                            iterator_resource.name + "-state")
    ]
    super(_Saveable, self).__init__(iterator_resource, specs,
                                    iterator_resource.name)
@@ -75,3 +77,160 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject):
  def restore(self, restored_tensors, unused_restored_shapes):
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+  """Checkpoints input pipeline state every N steps or seconds.
+
+  This hook saves the state of the iterators in the `Graph` so that when
+  training is resumed the input pipeline continues from where it left off.
+  This could potentially avoid overfitting in certain pipelines where the
+  number of training steps per eval is small compared to the dataset
+  size, or if the training pipeline is pre-empted.
+
+  Differences from `CheckpointSaverHook`:
+  1. Saves only the input pipelines in the "iterators" collection and not the
+     global variables or other saveable objects.
+  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+  Example of checkpointing the training pipeline:
+
+  ```python
+  est = tf.estimator.Estimator(model_fn)
+  while True:
+    est.train(
+        train_input_fn,
+        hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)],
+        steps=train_steps_per_eval)
+    # Note: We do not pass the hook here.
+    metrics = est.evaluate(eval_input_fn)
+    if should_stop_the_training(metrics):
+      break
+  ```
+
+  This hook should be used if the input pipeline state needs to be saved
+  separately from the model checkpoint. Doing so may be useful for a few reasons:
+  1. The input pipeline checkpoint may be large, if there are large shuffle
+     or prefetch buffers for instance, and may bloat the checkpoint size.
+  2. If the input pipeline is shared between training and validation, restoring
+     the checkpoint during validation may overwrite the validation input
+     pipeline.
+
+  To save the input pipeline checkpoint alongside the model weights, use
+  @{tf.contrib.data.make_saveable_from_iterator} directly to create a
+  `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection. Note,
+  however, that you will need to be careful not to restore the training
+  iterator during eval. You can do that by not adding the iterator to the
+  `SAVEABLE_OBJECTS` collection when building the eval graph.
+  """
+
+  def __init__(self, estimator):
+    """Initializes a `CheckpointInputPipelineHook`.
+
+    Args:
+      estimator: The `Estimator` whose input pipeline should be checkpointed.
+
+    Raises:
+      ValueError: One of `save_steps` or `save_secs` should be set.
+      ValueError: At most one of `saver` or `scaffold` should be set.
+    """
+    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+    # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+    # to be different to avoid conflicts with the model checkpoint.
+
+    # pylint: disable=protected-access
+    checkpoint_prefix = "input"
+    if estimator._config.num_worker_replicas > 1:
+      # Distributed setting.
+      suffix = "_{}_{}".format(estimator._config.task_type,
+                               estimator._config.task_id)
+      checkpoint_prefix += suffix
+    # pylint: enable=protected-access
+
+    # We use a composition paradigm instead of inheriting from
+    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+    # to check whether a `CheckpointSaverHook` is already present in the list
+    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+    # would thwart this behavior. This hook checkpoints *only the iterators*
+    # and not the graph variables.
+    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+        estimator.model_dir,
+        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
+        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
+        checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+    # Name for the protocol buffer file that will contain the list of most
+    # recent checkpoints stored as a `CheckpointState` protocol buffer.
+    # This file, kept in the same directory as the checkpoint files, is
+    # automatically managed by the `Saver` to keep track of recent checkpoints.
+    # The default name used by the `Saver` for this file is "checkpoint". Here
+    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+    # `checkpoint_dir` is the same as the model checkpoint directory, there are
+    # no conflicts during restore.
+    self._latest_filename = "checkpoint_" + checkpoint_prefix
+
+  def begin(self):
+    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+    # collection if no `Saver` or `Scaffold` is provided.
+    # pylint: disable=protected-access
+    if (self._checkpoint_saver_hook._saver is None and
+        self._checkpoint_saver_hook._scaffold is None):
+      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+      saveables = [_Saveable(i) for i in iterators]
+      self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+                                                        self._latest_filename)
+    # pylint: enable=protected-access
+    self._checkpoint_saver_hook.begin()
+
+  def after_create_session(self, session, coord):
+    # Check if there is an existing checkpoint. If so, restore from it.
+    # pylint: disable=protected-access
+    latest_checkpoint_path = saver_lib.latest_checkpoint(
+        self._checkpoint_saver_hook._checkpoint_dir,
+        latest_filename=self._latest_filename)
+    if latest_checkpoint_path:
+      self._checkpoint_saver_hook._get_saver().restore(session,
+                                                       latest_checkpoint_path)
+    else:
+      # The checkpoint saved here is the state at step "global_step".
+      # Note: We do not save the GraphDef or MetaGraphDef here.
+      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+      self._checkpoint_saver_hook._save(session, global_step)
+      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+    # pylint: enable=protected-access
+
+  def before_run(self, run_context):
+    return self._checkpoint_saver_hook.before_run(run_context)
+
+  def after_run(self, run_context, run_values):
+    self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+  def end(self, session):
+    self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+  """`Saver` with a different default `latest_filename`.
+
+  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+  the model checkpoint saved by the `CheckpointSaverHook`.
+  """
+
+  def __init__(self, var_list, latest_filename):
+    super(_CustomSaver, self).__init__(var_list)
+    self._latest_filename = latest_filename
+
+  def save(self,
+           sess,
+           save_path,
+           global_step=None,
+           latest_filename=None,
+           meta_graph_suffix="meta",
+           write_meta_graph=True,
+           write_state=True,
+           strip_default_attrs=False):
+    return super(_CustomSaver, self).save(
+        sess, save_path, global_step, latest_filename or self._latest_filename,
+        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
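
The docstring above recommends @{tf.contrib.data.make_saveable_from_iterator} for the case where iterator state should travel with the model checkpoint instead of being handled by this hook. A minimal sketch of that pattern, not part of this change (TF 1.x API; the dataset and variable names are illustrative):

```python
import tensorflow as tf

# Build the training input pipeline (illustrative dataset).
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()

# Wrap the iterator state in a SaveableObject and register it, so the
# regular model Saver serializes and restores it along with the weights.
# Do this only when building the training graph: registering it in the
# eval graph would restore the training iterator's position into the
# eval pipeline.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

saver = tf.train.Saver()  # now also snapshots iterator state
```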
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
new file mode 100644
index 0000000000..30a993b1f7
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test.TestCase):
+
+  @staticmethod
+  def _model_fn(features, labels, mode, config):
+    del labels
+    del mode
+    del config
+    global_step = training_util.get_or_create_global_step()
+    update_global_step_op = global_step.assign_add(1)
+    latest_feature = variables.Variable(
+        0, name='latest_feature', dtype=dtypes.int64)
+    store_latest_feature_op = latest_feature.assign(features)
+    ops.add_to_collection('my_vars', global_step)
+    ops.add_to_collection('my_vars', latest_feature)
+    return model_fn.EstimatorSpec(
+        mode='train',
+        train_op=control_flow_ops.group(
+            [update_global_step_op, store_latest_feature_op]),
+        loss=constant_op.constant(2.0))
+
+  def _read_vars(self, model_dir):
+    """Returns (global_step, latest_feature)."""
+    with ops.Graph().as_default() as g:
+      ckpt_path = saver_lib.latest_checkpoint(model_dir)
+      meta_filename = ckpt_path + '.meta'
+      saver_lib.import_meta_graph(meta_filename)
+      saver = saver_lib.Saver()
+      with self.test_session(graph=g) as sess:
+        saver.restore(sess, ckpt_path)
+        return sess.run(ops.get_collection('my_vars'))
+
+  def _build_iterator_saver_hook(self, est):
+    return iterator_ops.CheckpointInputPipelineHook(est)
+
+  def testReturnDatasetFromInputFn(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testBuildIteratorInInputFn(self):
+
+    def _input_fn():
+      ds = dataset_ops.Dataset.range(10)
+      iterator = ds.make_one_shot_iterator()
+      return iterator.get_next()
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+  def testDoNotRestore(self):
+
+    def _input_fn():
+      return dataset_ops.Dataset.range(10)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+    est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+    # Hook not provided, input pipeline was not restored.
+    est.train(_input_fn, steps=2)
+    self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+  def testRaiseErrorIfNoIterator(self):
+
+    def _input_fn():
+      return constant_op.constant(1, dtype=dtypes.int64)
+
+    est = estimator.Estimator(model_fn=self._model_fn)
+
+    with self.assertRaises(ValueError):
+      est.train(
+          _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+  test.main()
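
A detail worth tracing through `after_create_session` above: because `_CustomSaver` writes its `CheckpointState` file under a custom name, model checkpoints and input-pipeline checkpoints can coexist in one `model_dir` as two independent streams. A minimal sketch of resolving both, assuming the naming scheme from the hook (the directory path is illustrative):

```python
import tensorflow as tf

# Resolves via the default state file, "checkpoint", which tracks the
# model checkpoints ("model.ckpt-*") written by CheckpointSaverHook.
model_ckpt = tf.train.latest_checkpoint("/tmp/model_dir")

# Resolves via "checkpoint_input", the state file that _CustomSaver
# maintains for the input-pipeline checkpoints ("input.ckpt-*").
input_ckpt = tf.train.latest_checkpoint(
    "/tmp/model_dir", latest_filename="checkpoint_input")
```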
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 0c76afd29d..fd164277b6 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -52,6 +52,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

+# Collection of all IteratorResources in the `Graph`.
+GLOBAL_ITERATORS = "iterators"
+

@tf_export("data.Iterator")
class Iterator(object):
@@ -75,8 +78,7 @@ class Iterator(object):
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset.
      output_classes: A nested structure of Python `type` object corresponding
-        to each
-        component of an element of this iterator.
+        to each component of an element of this iterator.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer
@@ -86,6 +88,7 @@ class Iterator(object):
    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
+    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
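
This core change is what feeds `CheckpointInputPipelineHook.begin`: every `Iterator` constructed after it registers its resource tensor in the `GLOBAL_ITERATORS` collection. A minimal sketch of how an iterator lands in, and is read back from, that collection (internal TF 1.x imports as used throughout this change):

```python
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops

ds = dataset_ops.Dataset.range(10)
it = ds.make_one_shot_iterator()  # Iterator.__init__ registers the resource.

# The collection holds iterator *resource tensors*, not Iterator objects,
# which is why the hook wraps each entry in a _Saveable directly.
resources = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
assert it._iterator_resource in resources  # pylint: disable=protected-access
```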