about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-26 14:36:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:41:28 -0700
commita1801ecdbb75b4583d757204611afd9af28b4a49 (patch)
tree97f8e453de9d819a33bd4d54bc87278814c56fea
parent2116c6649cfe339ce8a3859eb425806db8ae32b9 (diff)
Add experimental asynchronous checkpoint hook.
This triggers checkpoints in a separate thread while allowing training to continue. This can effectively parallelize checkpointing and training for workloads like TPUEstimator, where the weights are only updated after a number of device iterations.

PiperOrigin-RevId: 214670991
-rw-r--r--tensorflow/contrib/tpu/BUILD22
-rw-r--r--tensorflow/contrib/tpu/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/python/tpu/async_checkpoint.py202
3 files changed, 225 insertions, 0 deletions
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4e0b61227e..8355c92a4d 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -36,6 +36,27 @@ cc_library(
)
py_library(
+ name = "async_checkpoint",
+ srcs = ["python/tpu/async_checkpoint.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:summary_ops_v2",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/error_handling.py",
@@ -46,6 +67,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":async_checkpoint",
":tpu_lib",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 3c0456dc2f..766466968a 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,6 +55,7 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+@@AsyncCheckpointSaverHook
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
new file mode 100644
index 0000000000..e06a720e82
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -0,0 +1,202 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+
+"""Hook for asynchronous checkpointing.
+
+This hook dispatches checkpoint writing operations in a separate thread to
+allow execution to continue on the main thread.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+import time
+
+from tensorflow.core.util.event_pb2 import SessionLog
+
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+from tensorflow.python.training.session_run_hook import SessionRunArgs
+from tensorflow.python.training.summary_io import SummaryWriterCache
+
+
class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds, writing asynchronously.

  Unlike the synchronous `CheckpointSaverHook`, the actual checkpoint write is
  dispatched on a background thread so the main training loop can continue
  while the save is in flight.  At most one save runs at a time: if a save is
  still in progress when the next trigger fires, that trigger is skipped.
  """

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes an `AsyncCheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create AsyncCheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    # Handle to the in-flight save thread, or None when no save is running.
    self._save_thread = None
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    # SecondOrStepTimer raises ValueError if neither save_secs nor
    # save_steps is set.
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    # Overridden via _set_steps_per_run when a single session.run advances
    # the global step by more than one (e.g. TPU training loops).
    self._steps_per_run = 1
    self._summary_writer = None
    self._global_step_tensor = None

  def _set_steps_per_run(self, steps_per_run):
    """Records how many global steps each session.run call advances."""
    self._steps_per_run = steps_per_run

  def begin(self):
    """Creates the summary writer and global-step read; notifies listeners."""
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    """Writes graph/meta-graph once and saves the initial checkpoint."""
    global_step = session.run(self._global_step_tensor)

    # We do write graph and saver_def at the first call of before_run.
    # We cannot do this in begin, since we let other hooks to change graph and
    # add variables in begin. Graph is finalized after all begin calls.
    training_util.write_graph(
        ops.get_default_graph().as_graph_def(add_shapes=True),
        self._checkpoint_dir, "graph.pbtxt")
    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Requests the (possibly stale) global step value for after_run."""
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    """Triggers an asynchronous save when the timer says it is due."""
    # run_values holds the step value captured *before* the train op ran, so
    # it may lag by up to _steps_per_run; use it only as a cheap pre-filter.
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        # _save currently always returns None (listener-requested stops are
        # not supported asynchronously), so this branch never stops training.
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    """Waits for any in-flight save, then saves a final checkpoint if needed."""
    if self._save_thread:
      logging.info("Waiting for any pending checkpoints to finish.")
      self._save_thread.join()

    last_step = session.run(self._global_step_tensor)

    # Save the last checkpoint synchronously if needed.
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step, asynchronous=False)

    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step, asynchronous=True):
    """Saves the latest checkpoint.

    Args:
      session: the session to run the save ops in.
      step: the global step value this checkpoint corresponds to.
      asynchronous: if True, dispatch the save on a background thread and
        return immediately; if False, run the save on the calling thread.

    Returns:
      None. Unlike `CheckpointSaverHook._save`, listener-requested stops are
      not propagated, because `after_save` may run off the main thread.
    """

    def _save_fn():
      """Run the saver process."""
      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

      start_time = time.time()
      for l in self._listeners:
        l.before_save(session, step)

      self._get_saver().save(session, self._save_path, global_step=step)
      self._summary_writer.add_session_log(
          SessionLog(
              status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
          step)

      for l in self._listeners:
        l.after_save(session, step)

      end_time = time.time()
      logging.info("Checkpoint actual writing time: (%.3f sec)",
                   end_time - start_time)
      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)

    if not asynchronous:
      _save_fn()
      return

    # Only one save may be in flight; give the previous one a brief grace
    # period to finish, and skip this checkpoint if it is still running.
    if self._save_thread is not None:
      self._save_thread.join(timeout=0.1)
      if self._save_thread.is_alive():
        logging.info("Saver thread still in progress, skipping checkpoint.")
        return

    self._save_thread = threading.Thread(target=_save_fn)
    self._save_thread.start()

  def _get_saver(self):
    """Returns the Saver to use: explicit, scaffold's, or from SAVERS.

    The collection lookup result is cached in self._saver so subsequent
    calls are cheap.

    Raises:
      RuntimeError: if the SAVERS collection is empty or has more than one
        entry and no saver/scaffold was provided.
    """
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]