path: root/tensorflow/contrib/estimator
author: Mustafa Ispir <ispir@google.com> 2018-08-14 10:57:45 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-14 11:02:32 -0700
commit: 65c2daaff0a11e257595dcff482868dd709cd931 (patch)
tree: 6f7c8d0c19f55476d1f37952cfad39b30476543c /tensorflow/contrib/estimator
parent: 10a7b245187ee4b85fad70a4b2f66f7aa997a401 (diff)
Provide a stopper hook which checks the latest checkpoint. This hook helps relieve the following edge case:
* global_step reaches last_step
* all workers except the chief stop due to the last_step check
* the chief starts writing the last checkpoint
* a PS is preempted while the chief is writing the checkpoint
* the chief restarts training from an older checkpoint
* at this point only the chief remains to handle the remaining global steps

PiperOrigin-RevId: 208675370
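A minimal usage sketch of the new hook (hedged: the LinearRegressor, toy input_fn, model_dir, step counts, and RunConfig below are illustrative assumptions, not part of this commit; only StopAtCheckpointStepHook comes from the change itself):

import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib

def input_fn():
  # Toy dataset; any input_fn works here.
  features = {'x': [[1.0], [2.0], [3.0], [4.0]]}
  labels = [[2.0], [4.0], [6.0], [8.0]]
  return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')],
    model_dir='/tmp/stop_hook_demo',
    # Save checkpoints often enough that the hook's checkpoint check can
    # actually observe the last step being written.
    config=tf.estimator.RunConfig(save_checkpoints_steps=50))

# Request stop only once the checkpoint on disk has reached step 100, so a
# chief that restarts from an older checkpoint still finishes the run.
hook = hooks_lib.StopAtCheckpointStepHook(
    model_dir=estimator.model_dir, last_step=100)
estimator.train(input_fn=input_fn, hooks=[hook])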
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/__init__.py                      1
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/hooks.py       53
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/hooks_test.py  57
3 files changed, 111 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index e1453ae1d0..6ad3a4a604 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,6 +45,7 @@ _allowed_symbols = [
    'clip_gradients_by_norm',
    'forward_features',
    'InMemoryEvaluatorHook',
+    'StopAtCheckpointStepHook',
    'logistic_regression_head',
    'multi_class_head',
    'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index caadafdfa6..faefda7c48 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import time
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.framework import ops
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training
+from tensorflow.python.training import training_util
# pylint: disable=protected-access
@@ -210,4 +212,55 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
+class StopAtCheckpointStepHook(training.SessionRunHook):
+  """Hook that requests stop at a specified step based on the checkpoint."""
+
+  def __init__(self, model_dir, last_step,
+               wait_after_file_check_secs=30):
+    """Initializes a `StopAtCheckpointStepHook`.
+
+    This hook requests stop after the given last step has been reached. It
+    checks the latest checkpoint to verify whether the last step has been
+    written to disk.
+
+    Args:
+      model_dir: Directory from which to read the global step of the latest
+        checkpoint.
+      last_step: Step after which to stop.
+      wait_after_file_check_secs: Reading the same file from many workers may
+        create I/O issues. To throttle that, we wait the given number of
+        seconds after each read of the file.
+
+    Raises:
+      ValueError: If one of the arguments is invalid.
+    """
+    if last_step is None:
+      raise ValueError('last_step must be specified.')
+    if model_dir is None:
+      raise ValueError('model_dir must be specified.')
+
+    self._model_dir = model_dir
+    self._last_step = last_step
+    self._wait_after_file_check_secs = wait_after_file_check_secs
+
+  def begin(self):
+    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
+    if self._global_step_tensor is None:
+      raise RuntimeError(
+          'Global step should be created to use StopAtCheckpointStepHook.')
+
+  def before_run(self, run_context):  # pylint: disable=unused-argument
+    return training.SessionRunArgs(self._global_step_tensor)
+
+  def after_run(self, run_context, run_values):
+    global_step = run_values.results + 1
+    if global_step >= self._last_step:
+      # Check the latest global step in the checkpoint to ensure that the
+      # targeted last step has been written to disk.
+
+      step = estimator_lib._load_global_step_from_checkpoint_dir(
+          self._model_dir)
+      if step >= self._last_step:
+        run_context.request_stop()
+      else:
+        time.sleep(self._wait_after_file_check_secs)
+
# pylint: enable=protected-access
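For reference, a hedged sketch of what estimator_lib._load_global_step_from_checkpoint_dir amounts to, written against public TF 1.x APIs (the private helper is assumed to behave like this; returning 0 when no readable checkpoint exists is what the hook's polling loop relies on):

import tensorflow as tf

def load_global_step_from_checkpoint_dir(checkpoint_dir):
  """Returns the saved global_step, or 0 if no checkpoint is readable yet."""
  try:
    # latest_checkpoint returns None when the directory has no checkpoint;
    # NewCheckpointReader then raises, which the bare except maps to 0.
    reader = tf.train.NewCheckpointReader(
        tf.train.latest_checkpoint(checkpoint_dir))
    return reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0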
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index ee88d5ecf5..42352aa3ff 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import glob
import json
import os
+import tempfile
+import time
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib
+from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -316,5 +319,59 @@ class InMemoryEvaluatorHookTest(test.TestCase):
estimator.train(input_fn, hooks=[evaluator])
+class StopAtCheckpointStepHookTest(test.TestCase):
+
+  def test_do_not_stop_if_checkpoint_is_not_there(self):
+    with ops.Graph().as_default():
+      step = training.create_global_step()
+      assign_ten = step.assign(10)
+      no_op = control_flow_ops.no_op()
+      hook = hooks_lib.StopAtCheckpointStepHook(
+          model_dir=tempfile.mkdtemp(), last_step=10)
+      with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+        mon_sess.raw_session().run(assign_ten)
+        with test.mock.patch.object(time, 'sleep') as mock_sleep:
+          mon_sess.run(no_op)
+          self.assertTrue(mock_sleep.called)
+        self.assertFalse(mon_sess.should_stop())
+
+  def test_do_not_stop_if_checkpoint_step_is_smaller(self):
+    model_dir = tempfile.mkdtemp()
+    with ops.Graph().as_default():
+      step = training.create_global_step()
+      assign_nine = step.assign(9)
+      assign_ten = step.assign(10)
+      no_op = control_flow_ops.no_op()
+      hook = hooks_lib.StopAtCheckpointStepHook(
+          model_dir=model_dir, last_step=10)
+      with tf_session.Session() as sess:
+        sess.run(assign_nine)
+        training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+      with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+        mon_sess.raw_session().run(assign_ten)
+        with test.mock.patch.object(time, 'sleep') as mock_sleep:
+          mon_sess.run(no_op)
+          self.assertTrue(mock_sleep.called)
+        self.assertFalse(mon_sess.should_stop())
+
+  def test_stop_if_checkpoint_step_is_laststep(self):
+    model_dir = tempfile.mkdtemp()
+    with ops.Graph().as_default():
+      step = training.create_global_step()
+      assign_ten = step.assign(10)
+      no_op = control_flow_ops.no_op()
+      hook = hooks_lib.StopAtCheckpointStepHook(
+          model_dir=model_dir, last_step=10)
+      with tf_session.Session() as sess:
+        sess.run(assign_ten)
+        training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+      with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+        mon_sess.raw_session().run(assign_ten)
+        with test.mock.patch.object(time, 'sleep') as mock_sleep:
+          mon_sess.run(no_op)
+          self.assertFalse(mock_sleep.called)
+        self.assertTrue(mon_sess.should_stop())
+
+
if __name__ == '__main__':
test.main()
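As a usage note (illustrative, not part of this commit): in a between-graph replicated setup, the chief can keep using training.StopAtStepHook while non-chief workers use the new hook, so workers stay alive until the final checkpoint actually lands on disk, which is the edge case the commit message describes:

import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib

def make_stop_hook(is_chief, model_dir, last_step):
  # The chief stops once it has run last_step; workers wait until a checkpoint
  # at (or past) last_step exists, covering the chief-restart edge case above.
  if is_chief:
    return tf.train.StopAtStepHook(last_step=last_step)
  return hooks_lib.StopAtCheckpointStepHook(
      model_dir=model_dir, last_step=last_step)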