author    Mustafa Ispir <ispir@google.com>    2016-11-11 16:53:17 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-11-11 17:03:52 -0800
commit    97770c050315370f17edd5c5933a2367a560dcf3 (patch)
tree      d2fee5145e7f224558d9bf133c03c570e3cad386 /tensorflow/python/training/monitored_session_test.py
parent    2ff6ffac1234e11e5cbfd2c96f207407889c707a (diff)
Extending MonitoredTrainingSession to make model saving and chief operations more flexible.
Change: 138934639
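As the new test below exercises, this change adds a chief_only_hooks argument to tf.train.MonitoredTrainingSession and lets save_checkpoint_secs=0 disable time-based checkpoint saving. A minimal usage sketch against the TF 1.x-era API (the LoggingChiefHook class, the checkpoint directory, and the toy train op are illustrative assumptions, not part of this commit):

import tensorflow as tf

class LoggingChiefHook(tf.train.SessionRunHook):
  """Illustrative hook; passed via chief_only_hooks so it attaches only on the chief."""

  def begin(self):
    print('chief-only hook initialized')

gstep = tf.contrib.framework.get_or_create_global_step()
train_op = tf.assign_add(gstep, 1)  # toy stand-in for a real training op

# save_checkpoint_secs=0 turns off the time-based checkpoint saver, so no
# checkpoint is written; chief_only_hooks are used only when is_chief=True.
with tf.train.MonitoredTrainingSession(
    is_chief=True,
    checkpoint_dir='/tmp/example_logdir',  # hypothetical path
    chief_only_hooks=[LoggingChiefHook()],
    save_checkpoint_secs=0) as session:
  session.run(train_op)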
Diffstat (limited to 'tensorflow/python/training/monitored_session_test.py')
-rw-r--r--    tensorflow/python/training/monitored_session_test.py    70
1 file changed, 49 insertions(+), 21 deletions(-)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 2fde009ce3..bd8cdad628 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -140,6 +140,33 @@ def _test_dir(temp_dir, test_name):
return test_dir
+class FakeHook(tf.train.SessionRunHook):
+
+  def __init__(self):
+    self.should_stop = False
+    self.request = None
+    self.call_counter = Counter()
+    self.last_run_context = None
+    self.last_run_values = None
+
+  def begin(self):
+    self.call_counter['begin'] += 1
+
+  def before_run(self, run_context):
+    self.call_counter['before_run'] += 1
+    self.last_run_context = run_context
+    return self.request
+
+  def after_run(self, run_context, run_values):
+    self.call_counter['after_run'] += 1
+    self.last_run_values = run_values
+    if self.should_stop:
+      run_context.request_stop()
+
+  def end(self, session):
+    self.call_counter['end'] += 1
+
+
class MonitoredTrainingSessionTest(tf.test.TestCase):
"""Tests MonitoredTrainingSession."""
@@ -173,6 +200,28 @@ class MonitoredTrainingSessionTest(tf.test.TestCase):
self.assertIn('my_summary_tag', tags)
self.assertIn('global_step/sec', tags)
+  def test_custom_saving(self):
+    logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
+    fake_hook = FakeHook()
+    with tf.Graph().as_default():
+      gstep = tf.contrib.framework.get_or_create_global_step()
+      do_step = tf.assign_add(gstep, 1)
+      with tf.train.MonitoredTrainingSession(
+          is_chief=True,
+          checkpoint_dir=logdir,
+          chief_only_hooks=[fake_hook],
+          save_checkpoint_secs=0) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+
+      # Check whether the custom hook was called.
+      self.assertEqual(1, fake_hook.call_counter['begin'])
+      # A restart will not find the checkpoint, since we didn't save.
+      with tf.train.MonitoredTrainingSession(
+          is_chief=True, checkpoint_dir=logdir) as session:
+        self.assertEqual(0, session.run(gstep))
+
class StopAtNSession(monitored_session._WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""
@@ -441,27 +490,6 @@ class FakeSession(monitored_session._WrappedSession):
return monitored_session._WrappedSession.run(self, fetches)
-class FakeHook(tf.train.SessionRunHook):
-
-  def __init__(self):
-    self.should_stop = False
-    self.request = None
-    self.call_counter = Counter()
-    self.last_run_context = None
-    self.last_run_values = None
-
-  def before_run(self, run_context):
-    self.call_counter['before_run'] += 1
-    self.last_run_context = run_context
-    return self.request
-
-  def after_run(self, run_context, run_values):
-    self.call_counter['after_run'] += 1
-    self.last_run_values = run_values
-    if self.should_stop:
-      run_context.request_stop()
-
-
class HookedSessionTest(tf.test.TestCase):
def testRunPassesAllArguments(self):