From a1801ecdbb75b4583d757204611afd9af28b4a49 Mon Sep 17 00:00:00 2001
From: Russell Power
Date: Wed, 26 Sep 2018 14:36:59 -0700
Subject: Add experimental asynchronous checkpoint hook.

This triggers checkpoints in a separate thread while allowing training to
continue. This can effectively parallelize checkpointing and training for
workloads like TPUEstimator, where the weights are only updated after a
number of device iterations.

PiperOrigin-RevId: 214670991
---
 tensorflow/contrib/tpu/BUILD                      |  22 +++
 tensorflow/contrib/tpu/__init__.py                |   1 +
 .../contrib/tpu/python/tpu/async_checkpoint.py    | 202 +++++++++++++++++++++
 3 files changed, 225 insertions(+)
 create mode 100644 tensorflow/contrib/tpu/python/tpu/async_checkpoint.py

diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4e0b61227e..8355c92a4d 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -35,6 +35,27 @@ cc_library(
     ],
 )
 
+py_library(
+    name = "async_checkpoint",
+    srcs = ["python/tpu/async_checkpoint.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:summary_ops_v2",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/estimator:estimator_py",
+    ],
+)
+
 py_library(
     name = "tpu_estimator",
     srcs = [
@@ -46,6 +67,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":async_checkpoint",
         ":tpu_lib",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
         "//tensorflow/compiler/xla/python_api:xla_shape",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 3c0456dc2f..766466968a 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,6 +55,7 @@
 @@TPUDistributionStrategy
 @@keras_to_tpu_model
 
+@@AsyncCheckpointSaverHook
 """
 
 from __future__ import absolute_import
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
new file mode 100644
index 0000000000..e06a720e82
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -0,0 +1,202 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+
+"""Hook for asynchronous checkpointing.
+
+This hook dispatches checkpoint writing operations in a separate thread to
+allow execution to continue on the main thread.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+import time
+
+from tensorflow.core.util.event_pb2 import SessionLog
+
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+from tensorflow.python.training.session_run_hook import SessionRunArgs
+from tensorflow.python.training.summary_io import SummaryWriterCache
+
+
+class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+  """Saves checkpoints every N steps or seconds."""
+
+  def __init__(self,
+               checkpoint_dir,
+               save_secs=None,
+               save_steps=None,
+               saver=None,
+               checkpoint_basename="model.ckpt",
+               scaffold=None,
+               listeners=None):
+    """Initializes a `CheckpointSaverHook`.
+
+    Args:
+      checkpoint_dir: `str`, base directory for the checkpoint files.
+      save_secs: `int`, save every N secs.
+      save_steps: `int`, save every N steps.
+      saver: `Saver` object, used for saving.
+      checkpoint_basename: `str`, base name for the checkpoint files.
+      scaffold: `Scaffold`, use to get saver object.
+      listeners: List of `CheckpointSaverListener` subclass instances. Used for
+        callbacks that run immediately before or after this hook saves the
+        checkpoint.
+
+    Raises:
+      ValueError: One of `save_steps` or `save_secs` should be set.
+      ValueError: At most one of `saver` or `scaffold` should be set.
+    """
+    logging.info("Create CheckpointSaverHook.")
+    if saver is not None and scaffold is not None:
+      raise ValueError("You cannot provide both saver and scaffold.")
+    self._saver = saver
+    self._save_thread = None
+    self._checkpoint_dir = checkpoint_dir
+    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
+    self._scaffold = scaffold
+    self._timer = basic_session_run_hooks.SecondOrStepTimer(
+        every_secs=save_secs, every_steps=save_steps)
+    self._listeners = listeners or []
+    self._steps_per_run = 1
+    self._summary_writer = None
+    self._global_step_tensor = None
+
+  def _set_steps_per_run(self, steps_per_run):
+    self._steps_per_run = steps_per_run
+
+  def begin(self):
+    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
+    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
+    if self._global_step_tensor is None:
+      raise RuntimeError(
+          "Global step should be created to use CheckpointSaverHook.")
+    for l in self._listeners:
+      l.begin()
+
+  def after_create_session(self, session, coord):
+    global_step = session.run(self._global_step_tensor)
+
+    # We do write graph and saver_def at the first call of before_run.
+    # We cannot do this in begin, since we let other hooks to change graph and
+    # add variables in begin. Graph is finalized after all begin calls.
+    training_util.write_graph(
+        ops.get_default_graph().as_graph_def(add_shapes=True),
+        self._checkpoint_dir, "graph.pbtxt")
+    saver_def = self._get_saver().saver_def if self._get_saver() else None
+    graph = ops.get_default_graph()
+    meta_graph_def = meta_graph.create_meta_graph_def(
+        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
+    self._summary_writer.add_graph(graph)
+    self._summary_writer.add_meta_graph(meta_graph_def)
+    # The checkpoint saved here is the state at step "global_step".
+    self._save(session, global_step)
+    self._timer.update_last_triggered_step(global_step)
+
+  def before_run(self, run_context):  # pylint: disable=unused-argument
+    return SessionRunArgs(self._global_step_tensor)
+
+  def after_run(self, run_context, run_values):
+    stale_global_step = run_values.results
+    if self._timer.should_trigger_for_step(stale_global_step +
+                                           self._steps_per_run):
+      # get the real value after train op.
+      global_step = run_context.session.run(self._global_step_tensor)
+      if self._timer.should_trigger_for_step(global_step):
+        self._timer.update_last_triggered_step(global_step)
+        if self._save(run_context.session, global_step):
+          run_context.request_stop()
+
+  def end(self, session):
+    if self._save_thread:
+      logging.info("Waiting for any pending checkpoints to finish.")
+      self._save_thread.join()
+
+    last_step = session.run(self._global_step_tensor)
+
+    # Save the last checkpoint synchronously if needed.
+    if last_step != self._timer.last_triggered_step():
+      self._save(session, last_step, asynchronous=False)
+
+    for l in self._listeners:
+      l.end(session, last_step)
+
+  def _save(self, session, step, asynchronous=True):
+    """Saves the latest checkpoint, returns should_stop."""
+
+    def _save_fn():
+      """Run the saver process."""
+      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
+
+      start_time = time.time()
+      for l in self._listeners:
+        l.before_save(session, step)
+
+      self._get_saver().save(session, self._save_path, global_step=step)
+      self._summary_writer.add_session_log(
+          SessionLog(
+              status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+          step)
+      end_time = time.time()
+      logging.info("Checkpoint actual writing time: (%.3f sec)",
+                   end_time - start_time)
+      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
+
+    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
+    for l in self._listeners:
+      l.before_save(session, step)
+
+    if not asynchronous:
+      _save_fn()
+      return
+
+    if self._save_thread is not None:
+      self._save_thread.join(timeout=0.1)
+      if self._save_thread.is_alive():
+        logging.info("Saver thread still in progress, skipping checkpoint.")
+        return
+
+    self._save_thread = threading.Thread(target=_save_fn)
+    self._save_thread.start()
+
+  def _get_saver(self):
+    if self._saver is not None:
+      return self._saver
+    elif self._scaffold is not None:
+      return self._scaffold.saver
+
+    # Get saver from the SAVERS collection if present.
+    collection_key = ops.GraphKeys.SAVERS
+    savers = ops.get_collection(collection_key)
+    if not savers:
+      raise RuntimeError(
+          "No items in collection {}. Please add a saver to the collection "
+          "or provide a saver or scaffold.".format(collection_key))
+    elif len(savers) > 1:
+      raise RuntimeError(
+          "More than one item in collection {}. "
+          "Please indicate which one to use by passing it to the constructor."
+          .format(collection_key))
+
+    self._saver = savers[0]
+    return savers[0]
-- 
cgit v1.2.3
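
A usage sketch (not part of this change): AsyncCheckpointSaverHook is an
ordinary SessionRunHook, so it can be passed to TPUEstimator.train() through
the hooks argument. The snippet below is one plausible TF 1.x wiring; the
model_fn, input_fn, model directory, batch size, and step counts are
placeholders invented for illustration, and disabling the default saver via
RunConfig is an assumption about how to keep the built-in synchronous
CheckpointSaverHook from running alongside this one.

    # Hypothetical example -- not from the commit above.
    import tensorflow as tf

    from tensorflow.contrib.tpu.python.tpu import async_checkpoint


    def model_fn(features, labels, mode, params):
      # Toy one-layer classifier, only here to keep the sketch self-contained.
      del params  # unused
      logits = tf.layers.dense(features, 10)
      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
      train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
          loss, global_step=tf.train.get_global_step())
      return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)


    def input_fn(params):
      # TPUEstimator passes the per-replica batch size via params["batch_size"].
      features = tf.random_uniform([1024, 32])
      labels = tf.random_uniform([1024], maxval=10, dtype=tf.int32)
      dataset = tf.data.Dataset.from_tensor_slices((features, labels))
      return dataset.repeat().batch(params["batch_size"], drop_remainder=True)


    model_dir = "/tmp/async_ckpt_example"  # placeholder directory

    run_config = tf.contrib.tpu.RunConfig(
        model_dir=model_dir,
        # Disable the default synchronous CheckpointSaverHook so only the
        # asynchronous hook writes checkpoints.
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=False,  # CPU for illustration; set True with a real TPU target.
        train_batch_size=64)

    # Checkpoints are written from a background thread every 100 steps; if a
    # save is still in flight when the next trigger fires, that trigger is
    # skipped rather than queued.
    hook = async_checkpoint.AsyncCheckpointSaverHook(
        checkpoint_dir=model_dir, save_steps=100)

    estimator.train(input_fn=input_fn, max_steps=1000, hooks=[hook])

Because _save() skips a trigger while a previous background save is still
running, frequent triggers degrade gracefully instead of queueing writes, and
end() joins the worker thread and saves a final checkpoint synchronously if
the last step has not yet been written.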