diff options
author | 2016-11-11 16:53:17 -0800 | |
---|---|---|
committer | 2016-11-11 17:03:52 -0800 | |
commit | 97770c050315370f17edd5c5933a2367a560dcf3 (patch) | |
tree | d2fee5145e7f224558d9bf133c03c570e3cad386 /tensorflow/python/training/monitored_session_test.py | |
parent | 2ff6ffac1234e11e5cbfd2c96f207407889c707a (diff) |
Extending MonitoredTrainingSession to make model saving and chief operations more flexible.
Change: 138934639
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r-- | tensorflow/python/training/monitored_session_test.py | 70 |
1 files changed, 49 insertions, 21 deletions
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 2fde009ce3..bd8cdad628 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -140,6 +140,33 @@ def _test_dir(temp_dir, test_name): return test_dir +class FakeHook(tf.train.SessionRunHook): + + def __init__(self): + self.should_stop = False + self.request = None + self.call_counter = Counter() + self.last_run_context = None + self.last_run_values = None + + def begin(self): + self.call_counter['begin'] += 1 + + def before_run(self, run_context): + self.call_counter['before_run'] += 1 + self.last_run_context = run_context + return self.request + + def after_run(self, run_context, run_values): + self.call_counter['after_run'] += 1 + self.last_run_values = run_values + if self.should_stop: + run_context.request_stop() + + def end(self, session): + self.call_counter['end'] += 1 + + class MonitoredTrainingSessionTest(tf.test.TestCase): """Tests MonitoredTrainingSession.""" @@ -173,6 +200,28 @@ class MonitoredTrainingSessionTest(tf.test.TestCase): self.assertIn('my_summary_tag', tags) self.assertIn('global_step/sec', tags) + def test_custom_saving(self): + logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint') + fake_hook = FakeHook() + with tf.Graph().as_default(): + gstep = tf.contrib.framework.get_or_create_global_step() + do_step = tf.assign_add(gstep, 1) + with tf.train.MonitoredTrainingSession( + is_chief=True, + checkpoint_dir=logdir, + chief_only_hooks=[fake_hook], + save_checkpoint_secs=0) as session: + self.assertEqual(0, session.run(gstep)) + self.assertEqual(1, session.run(do_step)) + self.assertEqual(2, session.run(do_step)) + + # Check whether custom hook called or not + self.assertEqual(1, fake_hook.call_counter['begin']) + # A restart will not find the checkpoint, since we didn't save. + with tf.train.MonitoredTrainingSession( + is_chief=True, checkpoint_dir=logdir) as session: + self.assertEqual(0, session.run(gstep)) + class StopAtNSession(monitored_session._WrappedSession): """A wrapped session that stops at the N-th call to _check_stop.""" @@ -441,27 +490,6 @@ class FakeSession(monitored_session._WrappedSession): return monitored_session._WrappedSession.run(self, fetches) -class FakeHook(tf.train.SessionRunHook): - - def __init__(self): - self.should_stop = False - self.request = None - self.call_counter = Counter() - self.last_run_context = None - self.last_run_values = None - - def before_run(self, run_context): - self.call_counter['before_run'] += 1 - self.last_run_context = run_context - return self.request - - def after_run(self, run_context, run_values): - self.call_counter['after_run'] += 1 - self.last_run_values = run_values - if self.should_stop: - run_context.request_stop() - - class HookedSessionTest(tf.test.TestCase): def testRunPassesAllArguments(self): |