aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/hooks_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/hooks_test.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks_test.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index ee88d5ecf5..42352aa3ff 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import glob
import json
import os
+import tempfile
+import time
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib
+from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -316,5 +319,59 @@ class InMemoryEvaluatorHookTest(test.TestCase):
estimator.train(input_fn, hooks=[evaluator])
+class StopAtCheckpointStepHookTest(test.TestCase):
+
+ def test_do_not_stop_if_checkpoint_is_not_there(self):
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=tempfile.mkdtemp(), last_step=10)
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_do_not_stop_if_checkpoint_step_is_smaller(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_nine = step.assign(9)
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_nine)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_stop_if_checkpoint_step_is_laststep(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib.StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_ten)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertFalse(mock_sleep.called)
+ self.assertTrue(mon_sess.should_stop())
+
+
if __name__ == '__main__':
test.main()