diff options
author | Mustafa Ispir <ispir@google.com> | 2018-08-16 14:45:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 14:56:33 -0700 |
commit | 1c7cafc092cf489b77f338478be5bde1721c5d71 (patch) | |
tree | 618393bb82adf2cf126f3bbd98ddd8954d0c48de /tensorflow/contrib/estimator | |
parent | 2a08dd7f020138f5b79af188504937321f4d542d (diff) |
Added a factory for StopAtCheckpointStepHook. The chief is responsible for saving the checkpoint, so StopAtCheckpointStepHook should not be used in chief mode.
PiperOrigin-RevId: 209051351
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/hooks.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/hooks_test.py | 32 |
3 files changed, 49 insertions, 6 deletions
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 6ad3a4a604..258860f263 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -45,7 +45,7 @@ _allowed_symbols = [ 'clip_gradients_by_norm', 'forward_features', 'InMemoryEvaluatorHook', - 'StopAtCheckpointStepHook', + 'make_stop_at_checkpoint_step_hook', 'logistic_regression_head', 'multi_class_head', 'multi_head', diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index 30455314a9..66c46e66b7 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -213,8 +213,12 @@ class InMemoryEvaluatorHook(training.SessionRunHook): self._evaluate(session) -class StopAtCheckpointStepHook(training.SessionRunHook): - """Hook that requests stop at a specified step based on checkpoint.""" +class _StopAtCheckpointStepHook(training.SessionRunHook): + """Hook that requests stop at a specified step based on checkpoint. + + Note: We recommend using 'make_stop_at_checkpoint_step_hook` to get the proper + hook. 
+ """ def __init__(self, model_dir, last_step, wait_after_file_check_secs=30): @@ -264,4 +268,17 @@ class StopAtCheckpointStepHook(training.SessionRunHook): else: time.sleep(self._wait_after_file_check_secs) + +def make_stop_at_checkpoint_step_hook(estimator, + last_step, + wait_after_file_check_secs=30): + """Creates a proper StopAtCheckpointStepHook based on chief status.""" + + if estimator.config.is_chief: + return training.StopAtStepHook(last_step=last_step) + return _StopAtCheckpointStepHook( + model_dir=estimator.model_dir, + last_step=last_step, + wait_after_file_check_secs=wait_after_file_check_secs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index 42352aa3ff..c6c6cad95a 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -326,7 +326,7 @@ class StopAtCheckpointStepHookTest(test.TestCase): step = training.create_global_step() assign_ten = step.assign(10) no_op = control_flow_ops.no_op() - hook = hooks_lib.StopAtCheckpointStepHook( + hook = hooks_lib._StopAtCheckpointStepHook( model_dir=tempfile.mkdtemp(), last_step=10) with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: mon_sess.raw_session().run(assign_ten) @@ -342,7 +342,7 @@ class StopAtCheckpointStepHookTest(test.TestCase): assign_nine = step.assign(9) assign_ten = step.assign(10) no_op = control_flow_ops.no_op() - hook = hooks_lib.StopAtCheckpointStepHook( + hook = hooks_lib._StopAtCheckpointStepHook( model_dir=model_dir, last_step=10) with tf_session.Session() as sess: sess.run(assign_nine) @@ -360,7 +360,7 @@ class StopAtCheckpointStepHookTest(test.TestCase): step = training.create_global_step() assign_ten = step.assign(10) no_op = control_flow_ops.no_op() - hook = hooks_lib.StopAtCheckpointStepHook( + hook = hooks_lib._StopAtCheckpointStepHook( model_dir=model_dir, 
last_step=10) with tf_session.Session() as sess: sess.run(assign_ten) @@ -372,6 +372,32 @@ class StopAtCheckpointStepHookTest(test.TestCase): self.assertFalse(mock_sleep.called) self.assertTrue(mon_sess.should_stop()) + def test_creates_regular_stop_at_step_hook_for_chief(self): + # by default an estimator is in chief mode + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300) + self.assertIsInstance(hook, training.StopAtStepHook) + self.assertEqual(300, hook._last_step) + + def test_creates_checkpoint_hook_for_workers(self): + + class FakeWorkerConfig(estimator_lib.RunConfig): + + @property + def is_chief(self): + return False + + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1], + config=FakeWorkerConfig()) + hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300) + self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook) + self.assertEqual(300, hook._last_step) + self.assertEqual(dnn.model_dir, hook._model_dir) + if __name__ == '__main__': test.main() |