aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/monitored_session_test.py
diff options
context:
space:
mode:
authorGravatar Jerry Liu <twairball@yahoo.com>2018-03-27 07:21:54 +0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-03-26 16:21:54 -0700
commite5dcaf921cf9feefd42b2ab176590c696b3b0285 (patch)
tree788e4d481c07e0e9a5f56a09b74b3f0a39f3d249 /tensorflow/python/training/monitored_session_test.py
parent73f40467bde137e2e2b31297b73944cc2830bdb7 (diff)
Fix #15900 (#16154)
- Added `save_checkpoint_steps` attribute to `MonitoredTrainingSession`. If both `save_checkpoint_steps` and `save_checkpoint_secs` are both `None` then default saver is disabled. Default is `save_checkpoint_secs=600` - Added `test_save_checkpoint_steps` - Updated golden file
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--tensorflow/python/training/monitored_session_test.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 159b2d5c16..3806056f01 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -282,6 +282,42 @@ class MonitoredTrainingSessionTest(test.TestCase):
is_chief=True, checkpoint_dir=logdir) as session:
self.assertEqual(2, session.run(gstep))
+ def test_save_checkpoint_steps(self):
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True,
+ checkpoint_dir=logdir,
+ save_checkpoint_steps=100,
+ log_step_count_steps=10) as session:
+ for _ in range(100):
+ session.run(new_gstep)
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(100, session.run(gstep))
+
+ def test_save_checkpoint_secs(self):
+ logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
+ with ops.Graph().as_default():
+ gstep = variables_lib.get_or_create_global_step()
+ new_gstep = state_ops.assign_add(gstep, 1)
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True,
+ checkpoint_dir=logdir,
+ save_checkpoint_secs=0.1,
+ log_step_count_steps=10) as session:
+ session.run(new_gstep)
+ time.sleep(0.2)
+ for _ in range(10):
+ session.run(new_gstep)
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredTrainingSession(
+ is_chief=True, checkpoint_dir=logdir) as session:
+ self.assertEqual(11, session.run(gstep))
+
def test_summaries_steps(self):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
with ops.Graph().as_default():