about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2018-09-28 16:41:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 16:45:25 -0700
commit541677bfee008a093daab2d033bd72650d886126 (patch)
tree0b8373dfac0b18ac70f29f78e379ceeb8b185015 /tensorflow/contrib/tpu
parent0a341bbcb35d72d14bfda17f7f0cb0c61f323bce (diff)
Add option to disable initialization/shutdown of the TPU.
PiperOrigin-RevId: 215016286
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/__init__.py3
-rw-r--r--tensorflow/contrib/tpu/python/tpu/async_checkpoint.py12
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py9
3 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 766466968a..6ce6b779a2 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,7 +55,9 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+
@@AsyncCheckpointSaverHook
+@@TPUInMemoryEvalHook
"""
from __future__ import absolute_import
@@ -65,6 +67,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.async_checkpoint import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
index e06a720e82..20b7ba0997 100644
--- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
-
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
@@ -28,18 +27,16 @@ import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
-
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
-class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
@@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
- logging.info("Create CheckpointSaverHook.")
+ logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
@@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
+ # Skip saving on step 0
+ if step == 0:
+ return
+
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
@@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
- logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
l.before_save(session, step)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 764d85877a..545cee637f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._feed_error = None
self._finished = False
+ self._should_initialize_tpu = True
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
- self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ if self._should_initialize_tpu:
+ self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ else:
+ self._init_ops = []
+ self._finalize_ops = []
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
self._init_ops.extend(summary_writer_init_ops)