aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/hooks.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/hooks.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index caadafdfa6..faefda7c48 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import time
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.framework import ops
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training
+from tensorflow.python.training import training_util
# pylint: disable=protected-access
@@ -210,4 +212,55 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
+class StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint."""
+
+ def __init__(self, model_dir, last_step,
+ wait_after_file_check_secs=30):
+ """Initializes a `StopAtCheckpointStepHook`.
+
+ This hook requests stop after a last step has been reached. It checks latest
+ checkpoint to verify last step is written on disk or not.
+
+ Args:
+ model_dir: Directory to read global step from latest checkpoint.
+ last_step: Step after which to stop.
+ wait_after_file_check_secs: Reading same file by many workers may create
+ I/O issues. To throttle that we will wait given secs after each read of
+ the file.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+ """
+ if last_step is None:
+ raise ValueError('last_step must be specified.')
+ if model_dir is None:
+ raise ValueError('model_dir must be specified.')
+
+ self._model_dir = model_dir
+ self._last_step = last_step
+ self._wait_after_file_check_secs = wait_after_file_check_secs
+
+ def begin(self):
+ self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ if self._global_step_tensor is None:
+ raise RuntimeError(
+ 'Global step should be created to use StopAtCheckpointStepHook.')
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return training.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results + 1
+ if global_step >= self._last_step:
+ # Check latest global step in the checkpoint to ensure that the targeted
+ # last step is written on disk.
+
+ step = estimator_lib._load_global_step_from_checkpoint_dir(
+ self._model_dir)
+ if step >= self._last_step:
+ run_context.request_stop()
+ else:
+ time.sleep(self._wait_after_file_check_secs)
+
# pylint: enable=protected-access