author    Mustafa Ispir <ispir@google.com>  2018-08-16 14:45:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 14:56:33 -0700
commit    1c7cafc092cf489b77f338478be5bde1721c5d71 (patch)
tree      618393bb82adf2cf126f3bbd98ddd8954d0c48de /tensorflow/contrib/estimator
parent    2a08dd7f020138f5b79af188504937321f4d542d (diff)
Added a factory for StopAtCheckpointStepHook. The chief is responsible for saving the checkpoint, so StopAtCheckpointStepHook should not be used on the chief.
PiperOrigin-RevId: 209051351
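
For context (not part of this change), a minimal usage sketch of the new factory, assuming TensorFlow 1.x with tf.contrib available; the classifier, feature column, and input_fn below are illustrative:

import numpy as np
import tensorflow as tf

def input_fn():
  # Tiny illustrative dataset.
  features = {'x': np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)}
  labels = np.array([0, 0, 1, 1], dtype=np.int32)
  return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

estimator = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[3, 1])

# On the chief this returns a plain tf.train.StopAtStepHook; on workers it
# returns the checkpoint-based hook bound to estimator.model_dir, so training
# stops once the chief has written a checkpoint at or past last_step.
stop_hook = tf.contrib.estimator.make_stop_at_checkpoint_step_hook(
    estimator, last_step=1000)

estimator.train(input_fn=input_fn, hooks=[stop_hook])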
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/__init__.py                      2
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/hooks.py       21
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/hooks_test.py  32
3 files changed, 49 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 6ad3a4a604..258860f263 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,7 +45,7 @@ _allowed_symbols = [
'clip_gradients_by_norm',
'forward_features',
'InMemoryEvaluatorHook',
- 'StopAtCheckpointStepHook',
+ 'make_stop_at_checkpoint_step_hook',
'logistic_regression_head',
'multi_class_head',
'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 30455314a9..66c46e66b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -213,8 +213,12 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
-class StopAtCheckpointStepHook(training.SessionRunHook):
- """Hook that requests stop at a specified step based on checkpoint."""
+class _StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint.
+
+ Note: We recommend using `make_stop_at_checkpoint_step_hook` to get the proper
+ hook.
+ """
def __init__(self, model_dir, last_step,
wait_after_file_check_secs=30):
@@ -264,4 +268,17 @@ class StopAtCheckpointStepHook(training.SessionRunHook):
else:
time.sleep(self._wait_after_file_check_secs)
+
+def make_stop_at_checkpoint_step_hook(estimator,
+ last_step,
+ wait_after_file_check_secs=30):
+ """Creates a proper StopAtCheckpointStepHook based on chief status."""
+
+ if estimator.config.is_chief:
+ return training.StopAtStepHook(last_step=last_step)
+ return _StopAtCheckpointStepHook(
+ model_dir=estimator.model_dir,
+ last_step=last_step,
+ wait_after_file_check_secs=wait_after_file_check_secs)
+
# pylint: enable=protected-access
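
As a rough sketch of the non-chief path (again not part of this change; the TF_CONFIG values and hyperparameters are placeholders): when RunConfig.is_chief is False, the factory returns the checkpoint-based hook, which reads the latest checkpoint in estimator.model_dir and sleeps wait_after_file_check_secs between checks until the saved global step reaches last_step.

import json
import os
import tensorflow as tf

# Illustrative TF_CONFIG for a non-chief worker; host addresses are placeholders.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'chief': ['host0:2222'],
        'worker': ['host1:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})

config = tf.estimator.RunConfig()  # is_chief is False for this worker task
estimator = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[3, 1],
    config=config)

# Non-chief path: the factory returns the checkpoint-based hook, which waits
# 60 seconds between checkpoint checks and requests stop once the checkpointed
# global step reaches last_step.
hook = tf.contrib.estimator.make_stop_at_checkpoint_step_hook(
    estimator, last_step=5000, wait_after_file_check_secs=60)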
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index 42352aa3ff..c6c6cad95a 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -326,7 +326,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=tempfile.mkdtemp(), last_step=10)
with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
mon_sess.raw_session().run(assign_ten)
@@ -342,7 +342,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
assign_nine = step.assign(9)
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_nine)
@@ -360,7 +360,7 @@ class StopAtCheckpointStepHookTest(test.TestCase):
step = training.create_global_step()
assign_ten = step.assign(10)
no_op = control_flow_ops.no_op()
- hook = hooks_lib.StopAtCheckpointStepHook(
+ hook = hooks_lib._StopAtCheckpointStepHook(
model_dir=model_dir, last_step=10)
with tf_session.Session() as sess:
sess.run(assign_ten)
@@ -372,6 +372,32 @@ class StopAtCheckpointStepHookTest(test.TestCase):
self.assertFalse(mock_sleep.called)
self.assertTrue(mon_sess.should_stop())
+ def test_creates_regular_stop_at_step_hook_for_chief(self):
+ # by default an estimator is in chief mode
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1])
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, training.StopAtStepHook)
+ self.assertEqual(300, hook._last_step)
+
+ def test_creates_checkpoint_hook_for_workers(self):
+
+ class FakeWorkerConfig(estimator_lib.RunConfig):
+
+ @property
+ def is_chief(self):
+ return False
+
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1],
+ config=FakeWorkerConfig())
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook)
+ self.assertEqual(300, hook._last_step)
+ self.assertEqual(dnn.model_dir, hook._model_dir)
+
if __name__ == '__main__':
test.main()