aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/session_manager_test.py
diff options
context:
space:
mode:
authorGravatar Sherry Moore <sherrym@google.com>2016-05-03 12:05:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-03 13:11:44 -0700
commitb6a7b7dc5ed5e80051fe005445fa4663f67045d7 (patch)
tree8e1600a33ab5b6f470d0db928d203a1b2b7097d8 /tensorflow/python/training/session_manager_test.py
parent768b499811932904b897e4d00bee7ff3853416f5 (diff)
Added a new op report_uninitialized_variables() which returns a 1-D tensor
containing names of the uninitialized variables when run. Supervisor's ready_op is now implemented with report_uninitialized_variables(). If you write custom ready_op, please make sure 1. If the model is ready, it returns an empty tensor. 2. If the model is not ready, it returns a 1-D tensor containing reasons why the model is not ready. assert_variables_initialized() will be kept for 6 months for backward compatibility. Change: 121407132
Diffstat (limited to 'tensorflow/python/training/session_manager_test.py')
-rw-r--r--tensorflow/python/training/session_manager_test.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 9308548dd2..01cc4f11d6 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -31,6 +31,127 @@ class SessionManagerTest(tf.test.TestCase):
def testPrepareSessionSucceeds(self):
with tf.Graph().as_default():
v = tf.Variable([1.0, 2.0, 3.0], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("", init_op=tf.initialize_all_variables())
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+
+ def testPrepareSessionSucceedsWithInitFeedDict(self):
+ with tf.Graph().as_default():
+ p = tf.placeholder(tf.float32, shape=(3,))
+ v = tf.Variable(p, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("",
+ init_op=tf.initialize_all_variables(),
+ init_feed_dict={p: [1.0, 2.0, 3.0]})
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+
+ def testPrepareSessionSucceedsWithInitFn(self):
+ with tf.Graph().as_default():
+ v = tf.Variable([125], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("",
+ init_fn=lambda sess: sess.run(v.initializer))
+ self.assertAllClose([125], sess.run(v))
+
+ def testPrepareSessionFails(self):
+ checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
+ checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
+ try:
+ gfile.DeleteRecursively(checkpoint_dir)
+ gfile.DeleteRecursively(checkpoint_dir2)
+ except OSError:
+ pass # Ignore
+ gfile.MakeDirs(checkpoint_dir)
+
+ with tf.Graph().as_default():
+ v = tf.Variable([1.0, 2.0, 3.0], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess = sm.prepare_session("", init_op=tf.initialize_all_variables(),
+ saver=saver, checkpoint_dir=checkpoint_dir)
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+ checkpoint_filename = os.path.join(checkpoint_dir,
+ "prepare_session_checkpoint")
+ saver.save(sess, checkpoint_filename)
+ # Create a new Graph and SessionManager and recover.
+ with tf.Graph().as_default():
+ # Renames the checkpoint directory.
+ os.rename(checkpoint_dir, checkpoint_dir2)
+ gfile.MakeDirs(checkpoint_dir)
+ v = tf.Variable([6.0, 7.0, 8.0], name="v")
+ with self.test_session():
+ self.assertEqual(False, tf.is_variable_initialized(v).eval())
+ tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ # This should fail as there's no checkpoint within 2 seconds.
+ with self.assertRaisesRegexp(RuntimeError,
+ "no init_op or init_fn was given"):
+ sess = sm.prepare_session("", init_op=None, saver=saver,
+ checkpoint_dir=checkpoint_dir,
+ wait_for_checkpoint=True, max_wait_secs=2)
+ # Rename the checkpoint directory back.
+ gfile.DeleteRecursively(checkpoint_dir)
+ os.rename(checkpoint_dir2, checkpoint_dir)
+ # This should succeed as there's checkpoint.
+ sess = sm.prepare_session("", init_op=None, saver=saver,
+ checkpoint_dir=checkpoint_dir,
+ wait_for_checkpoint=True, max_wait_secs=2)
+ self.assertEqual(
+ True, tf.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
+ def testRecoverSession(self):
+ # Create a checkpoint.
+ checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
+ try:
+ gfile.DeleteRecursively(checkpoint_dir)
+ except OSError:
+ pass # Ignore
+ gfile.MakeDirs(checkpoint_dir)
+
+ with tf.Graph().as_default():
+ v = tf.Variable(1, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess, initialized = sm.recover_session("", saver=saver,
+ checkpoint_dir=checkpoint_dir)
+ self.assertFalse(initialized)
+ sess.run(v.initializer)
+ self.assertEquals(1, sess.run(v))
+ saver.save(sess, os.path.join(checkpoint_dir,
+ "recover_session_checkpoint"))
+ # Create a new Graph and SessionManager and recover.
+ with tf.Graph().as_default():
+ v = tf.Variable(2, name="v")
+ with self.test_session():
+ self.assertEqual(False, tf.is_variable_initialized(v).eval())
+ sm2 = tf.train.SessionManager(
+ ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess, initialized = sm2.recover_session("", saver=saver,
+ checkpoint_dir=checkpoint_dir)
+ self.assertTrue(initialized)
+ self.assertEqual(
+ True, tf.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+ self.assertEquals(1, sess.run(v))
+
+ def testWaitForSessionReturnsNoneAfterTimeout(self):
+ with tf.Graph().as_default():
+ tf.Variable(1, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables(),
+ recovery_wait_secs=1)
+
+ # Set max_wait_secs to allow us to try a few times.
+ with self.assertRaises(errors.DeadlineExceededError):
+ sm.wait_for_session(master="", max_wait_secs=3)
+
+
+class ObsoleteSessionManagerTest(tf.test.TestCase):
+
+ def testPrepareSessionSucceeds(self):
+ with tf.Graph().as_default():
+ v = tf.Variable([1.0, 2.0, 3.0], name="v")
sm = tf.train.SessionManager(ready_op=tf.assert_variables_initialized())
sess = sm.prepare_session("", init_op=tf.initialize_all_variables())
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@@ -145,5 +266,6 @@ class SessionManagerTest(tf.test.TestCase):
with self.assertRaises(errors.DeadlineExceededError):
sm.wait_for_session(master="", max_wait_secs=3)
+
if __name__ == "__main__":
tf.test.main()