diff options
author | 2016-05-03 12:05:21 -0800 | |
---|---|---|
committer | 2016-05-03 13:11:44 -0700 | |
commit | b6a7b7dc5ed5e80051fe005445fa4663f67045d7 (patch) | |
tree | 8e1600a33ab5b6f470d0db928d203a1b2b7097d8 /tensorflow/python/training/session_manager_test.py | |
parent | 768b499811932904b897e4d00bee7ff3853416f5 (diff) |
Added a new op report_uninitialized_variables() which returns a 1-D tensor
containing names of the uninitialized variables when run. Supervisor's
ready_op is now implemented with report_uninitialized_variables().
If you write custom ready_op, please make sure
1. If the model is ready, it returns an empty tensor.
2. If the model is not ready, it returns a 1-D tensor containing reasons why
the model is not ready.
assert_variables_initialized() will be kept for 6 months for backward compatibility.
Change: 121407132
Diffstat (limited to 'tensorflow/python/training/session_manager_test.py')
-rw-r--r-- | tensorflow/python/training/session_manager_test.py | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index 9308548dd2..01cc4f11d6 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -31,6 +31,127 @@ class SessionManagerTest(tf.test.TestCase): def testPrepareSessionSucceeds(self): with tf.Graph().as_default(): v = tf.Variable([1.0, 2.0, 3.0], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", init_op=tf.initialize_all_variables()) + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + + def testPrepareSessionSucceedsWithInitFeedDict(self): + with tf.Graph().as_default(): + p = tf.placeholder(tf.float32, shape=(3,)) + v = tf.Variable(p, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", + init_op=tf.initialize_all_variables(), + init_feed_dict={p: [1.0, 2.0, 3.0]}) + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + + def testPrepareSessionSucceedsWithInitFn(self): + with tf.Graph().as_default(): + v = tf.Variable([125], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", + init_fn=lambda sess: sess.run(v.initializer)) + self.assertAllClose([125], sess.run(v)) + + def testPrepareSessionFails(self): + checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") + checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2") + try: + gfile.DeleteRecursively(checkpoint_dir) + gfile.DeleteRecursively(checkpoint_dir2) + except OSError: + pass # Ignore + gfile.MakeDirs(checkpoint_dir) + + with tf.Graph().as_default(): + v = tf.Variable([1.0, 2.0, 3.0], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess = sm.prepare_session("", init_op=tf.initialize_all_variables(), + saver=saver, checkpoint_dir=checkpoint_dir) + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + checkpoint_filename = os.path.join(checkpoint_dir, + "prepare_session_checkpoint") + saver.save(sess, checkpoint_filename) + # Create a new Graph and SessionManager and recover. + with tf.Graph().as_default(): + # Renames the checkpoint directory. + os.rename(checkpoint_dir, checkpoint_dir2) + gfile.MakeDirs(checkpoint_dir) + v = tf.Variable([6.0, 7.0, 8.0], name="v") + with self.test_session(): + self.assertEqual(False, tf.is_variable_initialized(v).eval()) + tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + # This should fail as there's no checkpoint within 2 seconds. + with self.assertRaisesRegexp(RuntimeError, + "no init_op or init_fn was given"): + sess = sm.prepare_session("", init_op=None, saver=saver, + checkpoint_dir=checkpoint_dir, + wait_for_checkpoint=True, max_wait_secs=2) + # Rename the checkpoint directory back. + gfile.DeleteRecursively(checkpoint_dir) + os.rename(checkpoint_dir2, checkpoint_dir) + # This should succeed as there's checkpoint. + sess = sm.prepare_session("", init_op=None, saver=saver, + checkpoint_dir=checkpoint_dir, + wait_for_checkpoint=True, max_wait_secs=2) + self.assertEqual( + True, tf.is_variable_initialized( + sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) + + def testRecoverSession(self): + # Create a checkpoint. + checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session") + try: + gfile.DeleteRecursively(checkpoint_dir) + except OSError: + pass # Ignore + gfile.MakeDirs(checkpoint_dir) + + with tf.Graph().as_default(): + v = tf.Variable(1, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess, initialized = sm.recover_session("", saver=saver, + checkpoint_dir=checkpoint_dir) + self.assertFalse(initialized) + sess.run(v.initializer) + self.assertEquals(1, sess.run(v)) + saver.save(sess, os.path.join(checkpoint_dir, + "recover_session_checkpoint")) + # Create a new Graph and SessionManager and recover. + with tf.Graph().as_default(): + v = tf.Variable(2, name="v") + with self.test_session(): + self.assertEqual(False, tf.is_variable_initialized(v).eval()) + sm2 = tf.train.SessionManager( + ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess, initialized = sm2.recover_session("", saver=saver, + checkpoint_dir=checkpoint_dir) + self.assertTrue(initialized) + self.assertEqual( + True, tf.is_variable_initialized( + sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) + self.assertEquals(1, sess.run(v)) + + def testWaitForSessionReturnsNoneAfterTimeout(self): + with tf.Graph().as_default(): + tf.Variable(1, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables(), + recovery_wait_secs=1) + + # Set max_wait_secs to allow us to try a few times. + with self.assertRaises(errors.DeadlineExceededError): + sm.wait_for_session(master="", max_wait_secs=3) + + +class ObsoleteSessionManagerTest(tf.test.TestCase): + + def testPrepareSessionSucceeds(self): + with tf.Graph().as_default(): + v = tf.Variable([1.0, 2.0, 3.0], name="v") sm = tf.train.SessionManager(ready_op=tf.assert_variables_initialized()) sess = sm.prepare_session("", init_op=tf.initialize_all_variables()) self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) @@ -145,5 +266,6 @@ class SessionManagerTest(tf.test.TestCase): with self.assertRaises(errors.DeadlineExceededError): sm.wait_for_session(master="", max_wait_secs=3) + if __name__ == "__main__": tf.test.main() |