diff options
author | Russell Power <power@google.com> | 2018-09-28 16:41:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 16:45:25 -0700 |
commit | 541677bfee008a093daab2d033bd72650d886126 (patch) | |
tree | 0b8373dfac0b18ac70f29f78e379ceeb8b185015 /tensorflow/contrib/tpu | |
parent | 0a341bbcb35d72d14bfda17f7f0cb0c61f323bce (diff) |
Add option to disable initialization/shutdown of the TPU.
PiperOrigin-RevId: 215016286
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/__init__.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/async_checkpoint.py | 12 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 9 |
3 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 766466968a..6ce6b779a2 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -55,7 +55,9 @@ @@TPUDistributionStrategy @@keras_to_tpu_model + @@AsyncCheckpointSaverHook +@@TPUInMemoryEvalHook """ from __future__ import absolute_import @@ -65,6 +67,7 @@ from __future__ import print_function # pylint: disable=wildcard-import,unused-import from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * +from tensorflow.contrib.tpu.python.tpu.async_checkpoint import * from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index e06a720e82..20b7ba0997 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ====================================== - """Hook for asynchronous checkpointing. This hook dispatches checkpoint writing operations in a separate thread to @@ -28,18 +27,16 @@ import threading import time from tensorflow.core.util.event_pb2 import SessionLog - from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache -class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): +class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): """Saves checkpoints every N steps or seconds.""" def __init__(self, @@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): ValueError: One of `save_steps` or `save_secs` should be set. ValueError: At most one of `saver` or `scaffold` should be set. """ - logging.info("Create CheckpointSaverHook.") + logging.info("Create AsyncCheckpointSaverHook.") if saver is not None and scaffold is not None: raise ValueError("You cannot provide both saver and scaffold.") self._saver = saver @@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): def _save(self, session, step, asynchronous=True): """Saves the latest checkpoint, returns should_stop.""" + # Skip saving on step 0 + if step == 0: + return + def _save_fn(): """Run the saver process.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) @@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - logging.info("Saving checkpoints for %d into %s.", step, self._save_path) for l in self._listeners: l.before_save(session, step) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 764d85877a..545cee637f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False + self._should_initialize_tpu = True def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [tpu.initialize_system(job=self._master_job)] - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + if self._should_initialize_tpu: + self._init_ops = [tpu.initialize_system(job=self._master_job)] + self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + else: + self._init_ops = [] + self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() self._init_ops.extend(summary_writer_init_ops) |