From b6a7b7dc5ed5e80051fe005445fa4663f67045d7 Mon Sep 17 00:00:00 2001 From: Sherry Moore Date: Tue, 3 May 2016 12:05:21 -0800 Subject: Added a new op report_uninitialized_variables() which returns a 1-D tensor containing names of the uninitialized variables when run. Supervisor's ready_op is now implemented with report_uninitialized_variables(). If you write custom ready_op, please make sure 1. If the model is ready, it returns an empty tensor. 2. If the model is not ready, it returns a 1-D tensor containing reasons why the model is not ready. assert_variables_initialized() will be kept for 6 months for backward compatibility. Change: 121407132 --- tensorflow/python/kernel_tests/variables_test.py | 29 +++++ tensorflow/python/ops/state_ops.py | 1 + tensorflow/python/ops/variables.py | 49 ++++++++- tensorflow/python/training/session_manager.py | 26 ++++- tensorflow/python/training/session_manager_test.py | 122 +++++++++++++++++++++ tensorflow/python/training/supervisor.py | 18 ++- tensorflow/python/training/supervisor_test.py | 6 +- 7 files changed, 232 insertions(+), 19 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 3949c77c2a..ed29409ee3 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -360,6 +360,35 @@ class VariablesTestCase(tf.test.TestCase): class IsInitializedTest(tf.test.TestCase): + def testNoVars(self): + with tf.Graph().as_default(), self.test_session() as sess: + uninited = tf.report_uninitialized_variables() + self.assertEqual(0, sess.run(uninited).size) + + def testAssertVariablesInitialized(self): + with tf.Graph().as_default(), self.test_session() as sess: + v = tf.Variable([1, 2], name="v") + w = tf.Variable([3, 4], name="w") + _ = v, w + uninited = tf.report_uninitialized_variables() + self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited)) + 
tf.initialize_all_variables().run() + self.assertEqual(0, sess.run(uninited).size) + + def testVariableList(self): + with tf.Graph().as_default(), self.test_session() as sess: + v = tf.Variable([1, 2], name="v") + w = tf.Variable([3, 4], name="w") + uninited = tf.report_uninitialized_variables() + self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited)) + sess.run(w.initializer) + self.assertAllEqual(np.array([b"v"]), sess.run(uninited)) + v.initializer.run() + self.assertEqual(0, sess.run(uninited).size) + + +class ObsoleteIsInitializedTest(tf.test.TestCase): + def testNoVars(self): with tf.Graph().as_default(): self.assertEqual(None, tf.assert_variables_initialized()) diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 78d675aaad..19c2878b1e 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -31,6 +31,7 @@ collected in the graph. @@initialize_variables @@initialize_local_variables @@is_variable_initialized +@@report_uninitialized_variables @@assert_variables_initialized ## Saving and Restoring Variables diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 0f7bac4c86..328fdc6503 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -19,9 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import variable_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -841,13 +843,14 @@ def initialize_local_variables(): def is_variable_initialized(variable): - """Returns an Op to check if a variable has been initialized. + """Tests if a variable has been initialized. Args: variable: A `Variable`. 
Returns: - An operation to check whether a variable has been initialized. + Returns a scalar boolean Tensor, `True` if the variable has been + initialized, `False` otherwise. """ return state_ops.is_variable_initialized(variable) @@ -855,6 +858,9 @@ def is_variable_initialized(variable): def assert_variables_initialized(var_list=None): """Returns an Op to check if variables are initialized. + NOTE: This function is obsolete and will be removed in 6 months. Please + change your implementation to use `report_uninitialized_variables()`. + When run, the returned Op will raise the exception `FailedPreconditionError` if any of the variables has not yet been initialized. @@ -890,6 +896,45 @@ def assert_variables_initialized(var_list=None): return array_ops.pack(ranks) +def report_uninitialized_variables(var_list=None, + name="report_uninitialized_variables"): + """Adds ops to list the names of uninitialized variables. + + When run, it returns a 1-D tensor containing the names of uninitialized + variables if there are any, or an empty array if there are none. + + Args: + var_list: List of `Variable` objects to check. Defaults to the + value of `all_variables() + local_variables()`. + name: Optional name of the `Operation`. + + Returns: + A 1-D tensor containing names of the uninitialized variables, or an empty 1-D + tensor if there are no variables or no uninitialized variables. + """ + if var_list is None: + var_list = all_variables() + local_variables() + # Backwards compatibility for old-style variables. TODO(touts): remove. + if not var_list: + var_list = [] + for op in ops.get_default_graph().get_operations(): + if op.type in ["Variable", "AutoReloadVariable"]: + var_list.append(op.outputs[0]) + if not var_list: + # Return an empty tensor so we only need to check for returned tensor + # size being 0 as an indication of model ready.
+ return array_ops.constant([], dtype=dtypes.string, name=name) + else: + # Get a 1-D boolean tensor listing whether each variable is uninitialized. + variables_mask = math_ops.logical_not(array_ops.pack( + [state_ops.is_variable_initialized(v) for v in var_list])) + # Get a 1-D string tensor containing all the variable names. + variable_names_tensor = array_ops.constant([s.op.name for s in var_list]) + # Return a 1-D tensor containing all the names of uninitialized variables. + return array_ops.boolean_mask(variable_names_tensor, variables_mask, + name=name) + + # pylint: disable=protected-access ops.register_tensor_conversion_function(Variable, Variable._TensorConversionFunction) diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index f6b88bd872..d418604499 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -19,6 +19,7 @@ from __future__ import print_function import threading import time +import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import errors @@ -84,9 +85,12 @@ class SessionManager(object): The `local_init_op` is an `Operation` that is run always after a new session was created. If `None`, this step is skipped. - The `ready_op` is an `Operation`. The model is considered ready - if that operation succeeds. If `None`, the model is not checked - for readiness. + The `ready_op` is an `Operation` used to check if the model is ready. The + model is considered ready if that operation returns an empty string tensor. + If the operation returns a non-empty string tensor, the elements are + concatenated and used to indicate to the user why the model is not ready. + + If `ready_op` is `None`, the model is not checked for readiness. `recovery_wait_secs` is the number of seconds between checks that the model is ready.
It is used by processes to wait for a model to @@ -325,8 +329,20 @@ class SessionManager(object): return None else: try: - sess.run(self._ready_op) - return None + ready_value = sess.run(self._ready_op) + # The model is considered ready if ready_op returns an empty 1-D tensor. + # Also compare to `None` and dtype being int32 for backward + # compatibility. + if (ready_value is None or ready_value.dtype == np.int32 or + ready_value.size == 0): + return None + else: + # TODO(sherrym): If a custom ready_op returns other types of tensor, + # or strings other than variable names, this message could be + # confusing. + non_initialized_varnames = ", ".join( + [i.decode("utf-8") for i in ready_value]) + return "Variables not initialized: " + non_initialized_varnames except errors.FailedPreconditionError as e: if "uninitialized" not in str(e): logging.warning("Model not ready raised: %s", str(e)) diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index 9308548dd2..01cc4f11d6 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -28,6 +28,127 @@ from tensorflow.python.platform import gfile class SessionManagerTest(tf.test.TestCase): + def testPrepareSessionSucceeds(self): + with tf.Graph().as_default(): + v = tf.Variable([1.0, 2.0, 3.0], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", init_op=tf.initialize_all_variables()) + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + + def testPrepareSessionSucceedsWithInitFeedDict(self): + with tf.Graph().as_default(): + p = tf.placeholder(tf.float32, shape=(3,)) + v = tf.Variable(p, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", + init_op=tf.initialize_all_variables(), + init_feed_dict={p: [1.0, 2.0, 3.0]}) + self.assertAllClose([1.0, 2.0, 3.0], 
sess.run(v)) + + def testPrepareSessionSucceedsWithInitFn(self): + with tf.Graph().as_default(): + v = tf.Variable([125], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + sess = sm.prepare_session("", + init_fn=lambda sess: sess.run(v.initializer)) + self.assertAllClose([125], sess.run(v)) + + def testPrepareSessionFails(self): + checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") + checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2") + try: + gfile.DeleteRecursively(checkpoint_dir) + gfile.DeleteRecursively(checkpoint_dir2) + except OSError: + pass # Ignore + gfile.MakeDirs(checkpoint_dir) + + with tf.Graph().as_default(): + v = tf.Variable([1.0, 2.0, 3.0], name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess = sm.prepare_session("", init_op=tf.initialize_all_variables(), + saver=saver, checkpoint_dir=checkpoint_dir) + self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) + checkpoint_filename = os.path.join(checkpoint_dir, + "prepare_session_checkpoint") + saver.save(sess, checkpoint_filename) + # Create a new Graph and SessionManager and recover. + with tf.Graph().as_default(): + # Renames the checkpoint directory. + os.rename(checkpoint_dir, checkpoint_dir2) + gfile.MakeDirs(checkpoint_dir) + v = tf.Variable([6.0, 7.0, 8.0], name="v") + with self.test_session(): + self.assertEqual(False, tf.is_variable_initialized(v).eval()) + tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + # This should fail as there's no checkpoint within 2 seconds. + with self.assertRaisesRegexp(RuntimeError, + "no init_op or init_fn was given"): + sess = sm.prepare_session("", init_op=None, saver=saver, + checkpoint_dir=checkpoint_dir, + wait_for_checkpoint=True, max_wait_secs=2) + # Rename the checkpoint directory back. 
+ gfile.DeleteRecursively(checkpoint_dir) + os.rename(checkpoint_dir2, checkpoint_dir) + # This should succeed as there's checkpoint. + sess = sm.prepare_session("", init_op=None, saver=saver, + checkpoint_dir=checkpoint_dir, + wait_for_checkpoint=True, max_wait_secs=2) + self.assertEqual( + True, tf.is_variable_initialized( + sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) + + def testRecoverSession(self): + # Create a checkpoint. + checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session") + try: + gfile.DeleteRecursively(checkpoint_dir) + except OSError: + pass # Ignore + gfile.MakeDirs(checkpoint_dir) + + with tf.Graph().as_default(): + v = tf.Variable(1, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess, initialized = sm.recover_session("", saver=saver, + checkpoint_dir=checkpoint_dir) + self.assertFalse(initialized) + sess.run(v.initializer) + self.assertEquals(1, sess.run(v)) + saver.save(sess, os.path.join(checkpoint_dir, + "recover_session_checkpoint")) + # Create a new Graph and SessionManager and recover. + with tf.Graph().as_default(): + v = tf.Variable(2, name="v") + with self.test_session(): + self.assertEqual(False, tf.is_variable_initialized(v).eval()) + sm2 = tf.train.SessionManager( + ready_op=tf.report_uninitialized_variables()) + saver = tf.train.Saver({"v": v}) + sess, initialized = sm2.recover_session("", saver=saver, + checkpoint_dir=checkpoint_dir) + self.assertTrue(initialized) + self.assertEqual( + True, tf.is_variable_initialized( + sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) + self.assertEquals(1, sess.run(v)) + + def testWaitForSessionReturnsNoneAfterTimeout(self): + with tf.Graph().as_default(): + tf.Variable(1, name="v") + sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables(), + recovery_wait_secs=1) + + # Set max_wait_secs to allow us to try a few times. 
+ with self.assertRaises(errors.DeadlineExceededError): + sm.wait_for_session(master="", max_wait_secs=3) + + +class ObsoleteSessionManagerTest(tf.test.TestCase): + def testPrepareSessionSucceeds(self): with tf.Graph().as_default(): v = tf.Variable([1.0, 2.0, 3.0], name="v") @@ -145,5 +266,6 @@ class SessionManagerTest(tf.test.TestCase): with self.assertRaises(errors.DeadlineExceededError): sm.wait_for_session(master="", max_wait_secs=3) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 1e1cda23b6..676209ccac 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -232,12 +232,11 @@ class Supervisor(object): default `Graph`. The supervisor may add operations to the graph before creating a session, but the graph should not be modified by the caller after passing it to the supervisor. - ready_op: `Operation` to check if the model is initialized. This - operation is run by supervisors in `prepare_or_wait_for_session()` to - check if the model is ready to use. The model is considered ready if - that operation succeeds. Defaults to the operation returned from - `tf.assert_variables_initialized()` If `None`, the model is not checked - for readiness. + ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in + `prepare_or_wait_for_session()` to check if the model is ready to use. + The model is considered ready if it returns an empty array. Defaults to + the tensor returned from `tf.report_uninitialized_variables()` If + `None`, the model is not checked for readiness. is_chief: If True, create a chief supervisor in charge of initializing and restoring the model. If False, create a supervisor that relies on a chief supervisor for inits and restore. @@ -369,16 +368,15 @@ class Supervisor(object): """Initializes ready_op. Args: - ready_op: `Operation` to check if the model is initialized. 
+ ready_op: `Tensor` to check if the model is initialized. If it's set to USE_DEFAULT, creates an op that checks all the variables are initialized. """ if ready_op is Supervisor.USE_DEFAULT: ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP) if ready_op is None: - ready_op = variables.assert_variables_initialized() - if ready_op is not None: - ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op) + ready_op = variables.report_uninitialized_variables() + ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op) self._ready_op = ready_op def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None): diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index c4a30de5b1..9e8876eadf 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -411,7 +411,8 @@ class SupervisorTest(tf.test.TestCase): tf.Variable([4.0, 5.0, 6.0], name="w") # w will not be initialized. sv = tf.train.Supervisor(logdir=logdir, init_op=v.initializer) - with self.assertRaisesRegexp(RuntimeError, "uninitialized value w"): + with self.assertRaisesRegexp(RuntimeError, + "Variables not initialized: w"): sv.prepare_or_wait_for_session(server.target) def testInitOpFailsForTransientVariable(self): @@ -424,7 +425,8 @@ class SupervisorTest(tf.test.TestCase): collections=[tf.GraphKeys.LOCAL_VARIABLES]) # w will not be initialized. sv = tf.train.Supervisor(logdir=logdir, local_init_op=v.initializer) - with self.assertRaisesRegexp(RuntimeError, "uninitialized value w"): + with self.assertRaisesRegexp( + RuntimeError, "Variables not initialized: w"): sv.prepare_or_wait_for_session(server.target) def testSetupFail(self): -- cgit v1.2.3