aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
author: Sherry Moore <sherrym@google.com> 2016-05-03 12:05:21 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-05-03 13:11:44 -0700
commit: b6a7b7dc5ed5e80051fe005445fa4663f67045d7 (patch)
tree: 8e1600a33ab5b6f470d0db928d203a1b2b7097d8 /tensorflow/python
parent: 768b499811932904b897e4d00bee7ff3853416f5 (diff)
Added a new op report_uninitialized_variables() which returns a 1-D tensor
containing names of the uninitialized variables when run. Supervisor's ready_op is now implemented with report_uninitialized_variables(). If you write a custom ready_op, please make sure that: 1. If the model is ready, it returns an empty tensor. 2. If the model is not ready, it returns a 1-D tensor containing the reasons why the model is not ready. assert_variables_initialized() will be kept for 6 months for backward compatibility. Change: 121407132
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- tensorflow/python/kernel_tests/variables_test.py | 29
-rw-r--r-- tensorflow/python/ops/state_ops.py | 1
-rw-r--r-- tensorflow/python/ops/variables.py | 49
-rw-r--r-- tensorflow/python/training/session_manager.py | 26
-rw-r--r-- tensorflow/python/training/session_manager_test.py | 122
-rw-r--r-- tensorflow/python/training/supervisor.py | 18
-rw-r--r-- tensorflow/python/training/supervisor_test.py | 6
7 files changed, 232 insertions, 19 deletions
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 3949c77c2a..ed29409ee3 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -361,6 +361,35 @@ class VariablesTestCase(tf.test.TestCase):
class IsInitializedTest(tf.test.TestCase):
def testNoVars(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ uninited = tf.report_uninitialized_variables()
+ self.assertEqual(0, sess.run(uninited).size)
+
+ def testAssertVariablesInitialized(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ v = tf.Variable([1, 2], name="v")
+ w = tf.Variable([3, 4], name="w")
+ _ = v, w
+ uninited = tf.report_uninitialized_variables()
+ self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
+ tf.initialize_all_variables().run()
+ self.assertEqual(0, sess.run(uninited).size)
+
+ def testVariableList(self):
+ with tf.Graph().as_default(), self.test_session() as sess:
+ v = tf.Variable([1, 2], name="v")
+ w = tf.Variable([3, 4], name="w")
+ uninited = tf.report_uninitialized_variables()
+ self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
+ sess.run(w.initializer)
+ self.assertAllEqual(np.array([b"v"]), sess.run(uninited))
+ v.initializer.run()
+ self.assertEqual(0, sess.run(uninited).size)
+
+
+class ObsoleteIsInitializedTest(tf.test.TestCase):
+
+ def testNoVars(self):
with tf.Graph().as_default():
self.assertEqual(None, tf.assert_variables_initialized())
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 78d675aaad..19c2878b1e 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -31,6 +31,7 @@ collected in the graph.
@@initialize_variables
@@initialize_local_variables
@@is_variable_initialized
+@@report_uninitialized_variables
@@assert_variables_initialized
## Saving and Restoring Variables
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 0f7bac4c86..328fdc6503 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -19,9 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import variable_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@@ -841,13 +843,14 @@ def initialize_local_variables():
def is_variable_initialized(variable):
- """Returns an Op to check if a variable has been initialized.
+ """Tests if a variable has been initialized.
Args:
variable: A `Variable`.
Returns:
- An operation to check whether a variable has been initialized.
+ A scalar boolean `Tensor`: `True` if the variable has been
+ initialized, `False` otherwise.
"""
return state_ops.is_variable_initialized(variable)
@@ -855,6 +858,9 @@ def is_variable_initialized(variable):
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
+ NOTE: This function is obsolete and will be removed in 6 months. Please
+ change your implementation to use `report_uninitialized_variables()`.
+
When run, the returned Op will raise the exception `FailedPreconditionError`
if any of the variables has not yet been initialized.
@@ -890,6 +896,45 @@ def assert_variables_initialized(var_list=None):
return array_ops.pack(ranks)
+def report_uninitialized_variables(var_list=None,
+ name="report_uninitialized_variables"):
+ """Adds ops to list the names of uninitialized variables.
+
+ When run, it returns a 1-D tensor containing the names of uninitialized
+ variables if there are any, or an empty array if there are none.
+
+ Args:
+ var_list: List of `Variable` objects to check. Defaults to the
+ value of `all_variables() + local_variables()`.
+ name: Optional name of the `Operation`.
+
+ Returns:
+ A 1-D tensor containing names of the uninitialized variables, or an empty 1-D
+ tensor if there are no variables or no uninitialized variables.
+ """
+ if var_list is None:
+ var_list = all_variables() + local_variables()
+ # Backwards compatibility for old-style variables. TODO(touts): remove.
+ if not var_list:
+ var_list = []
+ for op in ops.get_default_graph().get_operations():
+ if op.type in ["Variable", "AutoReloadVariable"]:
+ var_list.append(op.outputs[0])
+ if not var_list:
+ # Return an empty tensor so we only need to check for returned tensor
+ # size being 0 as an indication of model ready.
+ return array_ops.constant([], dtype=dtypes.string, name=name)
+ else:
+ # Get a 1-D boolean tensor listing whether each variable is initialized.
+ variables_mask = math_ops.logical_not(array_ops.pack(
+ [state_ops.is_variable_initialized(v) for v in var_list]))
+ # Get a 1-D string tensor containing all the variable names.
+ variable_names_tensor = array_ops.constant([s.op.name for s in var_list])
+ # Return a 1-D tensor containing all the names of uninitialized variables.
+ return array_ops.boolean_mask(variable_names_tensor, variables_mask,
+ name=name)
+
+
# pylint: disable=protected-access
ops.register_tensor_conversion_function(Variable,
Variable._TensorConversionFunction)
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index f6b88bd872..d418604499 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import threading
import time
+import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import errors
@@ -84,9 +85,12 @@ class SessionManager(object):
The `local_init_op` is an `Operation` that is run always after a new session
was created. If `None`, this step is skipped.
- The `ready_op` is an `Operation`. The model is considered ready
- if that operation succeeds. If `None`, the model is not checked
- for readiness.
+ The `ready_op` is an `Operation` used to check if the model is ready. The
+ model is considered ready if that operation returns an empty string tensor.
+ If the operation returns a non-empty string tensor, the elements are
+ concatenated and used to indicate to the user why the model is not ready.
+
+ If `ready_op` is `None`, the model is not checked for readiness.
`recovery_wait_secs` is the number of seconds between checks that
the model is ready. It is used by processes to wait for a model to
@@ -325,8 +329,20 @@ class SessionManager(object):
return None
else:
try:
- sess.run(self._ready_op)
- return None
+ ready_value = sess.run(self._ready_op)
+ # The model is considered ready if ready_op returns an empty 1-D tensor.
+ # Also compare to `None` and dtype being int32 for backward
+ # compatibility.
+ if (ready_value is None or ready_value.dtype == np.int32 or
+ ready_value.size == 0):
+ return None
+ else:
+ # TODO(sherrym): If a custom ready_op returns other types of tensor,
+ # or strings other than variable names, this message could be
+ # confusing.
+ non_initialized_varnames = ", ".join(
+ [i.decode("utf-8") for i in ready_value])
+ return "Variables not initialized: " + non_initialized_varnames
except errors.FailedPreconditionError as e:
if "uninitialized" not in str(e):
logging.warning("Model not ready raised: %s", str(e))
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 9308548dd2..01cc4f11d6 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -31,6 +31,127 @@ class SessionManagerTest(tf.test.TestCase):
def testPrepareSessionSucceeds(self):
with tf.Graph().as_default():
v = tf.Variable([1.0, 2.0, 3.0], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("", init_op=tf.initialize_all_variables())
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+
+ def testPrepareSessionSucceedsWithInitFeedDict(self):
+ with tf.Graph().as_default():
+ p = tf.placeholder(tf.float32, shape=(3,))
+ v = tf.Variable(p, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("",
+ init_op=tf.initialize_all_variables(),
+ init_feed_dict={p: [1.0, 2.0, 3.0]})
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+
+ def testPrepareSessionSucceedsWithInitFn(self):
+ with tf.Graph().as_default():
+ v = tf.Variable([125], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ sess = sm.prepare_session("",
+ init_fn=lambda sess: sess.run(v.initializer))
+ self.assertAllClose([125], sess.run(v))
+
+ def testPrepareSessionFails(self):
+ checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
+ checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
+ try:
+ gfile.DeleteRecursively(checkpoint_dir)
+ gfile.DeleteRecursively(checkpoint_dir2)
+ except OSError:
+ pass # Ignore
+ gfile.MakeDirs(checkpoint_dir)
+
+ with tf.Graph().as_default():
+ v = tf.Variable([1.0, 2.0, 3.0], name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess = sm.prepare_session("", init_op=tf.initialize_all_variables(),
+ saver=saver, checkpoint_dir=checkpoint_dir)
+ self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
+ checkpoint_filename = os.path.join(checkpoint_dir,
+ "prepare_session_checkpoint")
+ saver.save(sess, checkpoint_filename)
+ # Create a new Graph and SessionManager and recover.
+ with tf.Graph().as_default():
+ # Renames the checkpoint directory.
+ os.rename(checkpoint_dir, checkpoint_dir2)
+ gfile.MakeDirs(checkpoint_dir)
+ v = tf.Variable([6.0, 7.0, 8.0], name="v")
+ with self.test_session():
+ self.assertEqual(False, tf.is_variable_initialized(v).eval())
+ tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ # This should fail as there's no checkpoint within 2 seconds.
+ with self.assertRaisesRegexp(RuntimeError,
+ "no init_op or init_fn was given"):
+ sess = sm.prepare_session("", init_op=None, saver=saver,
+ checkpoint_dir=checkpoint_dir,
+ wait_for_checkpoint=True, max_wait_secs=2)
+ # Rename the checkpoint directory back.
+ gfile.DeleteRecursively(checkpoint_dir)
+ os.rename(checkpoint_dir2, checkpoint_dir)
+ # This should succeed as there's a checkpoint.
+ sess = sm.prepare_session("", init_op=None, saver=saver,
+ checkpoint_dir=checkpoint_dir,
+ wait_for_checkpoint=True, max_wait_secs=2)
+ self.assertEqual(
+ True, tf.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
+ def testRecoverSession(self):
+ # Create a checkpoint.
+ checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
+ try:
+ gfile.DeleteRecursively(checkpoint_dir)
+ except OSError:
+ pass # Ignore
+ gfile.MakeDirs(checkpoint_dir)
+
+ with tf.Graph().as_default():
+ v = tf.Variable(1, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess, initialized = sm.recover_session("", saver=saver,
+ checkpoint_dir=checkpoint_dir)
+ self.assertFalse(initialized)
+ sess.run(v.initializer)
+ self.assertEquals(1, sess.run(v))
+ saver.save(sess, os.path.join(checkpoint_dir,
+ "recover_session_checkpoint"))
+ # Create a new Graph and SessionManager and recover.
+ with tf.Graph().as_default():
+ v = tf.Variable(2, name="v")
+ with self.test_session():
+ self.assertEqual(False, tf.is_variable_initialized(v).eval())
+ sm2 = tf.train.SessionManager(
+ ready_op=tf.report_uninitialized_variables())
+ saver = tf.train.Saver({"v": v})
+ sess, initialized = sm2.recover_session("", saver=saver,
+ checkpoint_dir=checkpoint_dir)
+ self.assertTrue(initialized)
+ self.assertEqual(
+ True, tf.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+ self.assertEquals(1, sess.run(v))
+
+ def testWaitForSessionReturnsNoneAfterTimeout(self):
+ with tf.Graph().as_default():
+ tf.Variable(1, name="v")
+ sm = tf.train.SessionManager(ready_op=tf.report_uninitialized_variables(),
+ recovery_wait_secs=1)
+
+ # Set max_wait_secs to allow us to try a few times.
+ with self.assertRaises(errors.DeadlineExceededError):
+ sm.wait_for_session(master="", max_wait_secs=3)
+
+
+class ObsoleteSessionManagerTest(tf.test.TestCase):
+
+ def testPrepareSessionSucceeds(self):
+ with tf.Graph().as_default():
+ v = tf.Variable([1.0, 2.0, 3.0], name="v")
sm = tf.train.SessionManager(ready_op=tf.assert_variables_initialized())
sess = sm.prepare_session("", init_op=tf.initialize_all_variables())
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@@ -145,5 +266,6 @@ class SessionManagerTest(tf.test.TestCase):
with self.assertRaises(errors.DeadlineExceededError):
sm.wait_for_session(master="", max_wait_secs=3)
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 1e1cda23b6..676209ccac 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -232,12 +232,11 @@ class Supervisor(object):
default `Graph`. The supervisor may add operations to the graph before
creating a session, but the graph should not be modified by the caller
after passing it to the supervisor.
- ready_op: `Operation` to check if the model is initialized. This
- operation is run by supervisors in `prepare_or_wait_for_session()` to
- check if the model is ready to use. The model is considered ready if
- that operation succeeds. Defaults to the operation returned from
- `tf.assert_variables_initialized()` If `None`, the model is not checked
- for readiness.
+ ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
+ `prepare_or_wait_for_session()` to check if the model is ready to use.
+ The model is considered ready if it returns an empty array. Defaults to
+ the tensor returned from `tf.report_uninitialized_variables()`. If
+ `None`, the model is not checked for readiness.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
@@ -369,16 +368,15 @@ class Supervisor(object):
"""Initializes ready_op.
Args:
- ready_op: `Operation` to check if the model is initialized.
+ ready_op: `Tensor` to check if the model is initialized.
If it's set to USE_DEFAULT, creates an op that checks all
the variables are initialized.
"""
if ready_op is Supervisor.USE_DEFAULT:
ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
if ready_op is None:
- ready_op = variables.assert_variables_initialized()
- if ready_op is not None:
- ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
+ ready_op = variables.report_uninitialized_variables()
+ ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
self._ready_op = ready_op
def _init_init_op(self, init_op=USE_DEFAULT, init_feed_dict=None):
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index c4a30de5b1..9e8876eadf 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -411,7 +411,8 @@ class SupervisorTest(tf.test.TestCase):
tf.Variable([4.0, 5.0, 6.0], name="w")
# w will not be initialized.
sv = tf.train.Supervisor(logdir=logdir, init_op=v.initializer)
- with self.assertRaisesRegexp(RuntimeError, "uninitialized value w"):
+ with self.assertRaisesRegexp(RuntimeError,
+ "Variables not initialized: w"):
sv.prepare_or_wait_for_session(server.target)
def testInitOpFailsForTransientVariable(self):
@@ -424,7 +425,8 @@ class SupervisorTest(tf.test.TestCase):
collections=[tf.GraphKeys.LOCAL_VARIABLES])
# w will not be initialized.
sv = tf.train.Supervisor(logdir=logdir, local_init_op=v.initializer)
- with self.assertRaisesRegexp(RuntimeError, "uninitialized value w"):
+ with self.assertRaisesRegexp(
+ RuntimeError, "Variables not initialized: w"):
sv.prepare_or_wait_for_session(server.target)
def testSetupFail(self):