aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--tensorflow/python/training/monitored_session_test.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 159b2d5c16..3806056f01 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -282,6 +282,42 @@ class MonitoredTrainingSessionTest(test.TestCase):
is_chief=True, checkpoint_dir=logdir) as session:
self.assertEqual(2, session.run(gstep))
+ def test_save_checkpoint_steps(self):
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True,
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(100, session.run(gstep))
+
+ def test_save_checkpoint_secs(self):
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True,
+ checkpoint_dir=logdir,
+ save_checkpoint_secs=0.1,
+ log_step_count_steps=10) as session:
+ session.run(new_gstep)
+ time.sleep(0.2)
+ for _ in range(10):
+ session.run(new_gstep)
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(11, session.run(gstep))
+
def test_summaries_steps(self):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
with ops.Graph().as_default():