From 82594d38f66aa86c37ea1deb8d0631efb7d9ba96 Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Wed, 21 Dec 2016 09:03:36 -0800 Subject: Bug fix related to a corner case in evaluate_once. Corner case: * if user defined an init_op which should not run after restoring. Current implementation caused to run init_op since session_manager thinks nothing is restored. Removed variables_to_restore argument since it's usage is conflicting with Scaffold.saver. If user defined a saver and variables_to_restore, current implementation silently ignores variables_to_restore. Used specific checkpoint file feature of session-manager which is recently added. Change: 142669254 --- tensorflow/contrib/slim/python/slim/evaluation.py | 17 +++++++---- .../contrib/training/python/training/evaluation.py | 35 ++-------------------- tensorflow/python/training/monitored_session.py | 11 +++++-- .../python/training/monitored_session_test.py | 7 +++++ 4 files changed, 30 insertions(+), 40 deletions(-) diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py index b89eca46ea..231b3af502 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation.py +++ b/tensorflow/contrib/slim/python/slim/evaluation.py @@ -125,6 +125,7 @@ from __future__ import print_function from tensorflow.contrib.training.python.training import evaluation from tensorflow.python import summary from tensorflow.python.training import monitored_session +from tensorflow.python.training import saver as tf_saver __all__ = [ 'evaluate_once', @@ -192,17 +193,19 @@ def evaluate_once(master, hooks.append( evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict)) + saver = None + if variables_to_restore is not None: + saver = tf_saver.Saver(variables_to_restore) + return evaluation.evaluate_once( checkpoint_path, master=master, scaffold=monitored_session.Scaffold( - init_op=initial_op, - init_feed_dict=initial_op_feed_dict), + init_op=initial_op, init_feed_dict=initial_op_feed_dict, saver=saver), eval_ops=eval_op, feed_dict=eval_op_feed_dict, final_ops=final_op, final_ops_feed_dict=final_op_feed_dict, - variables_to_restore=variables_to_restore, hooks=hooks, config=session_config) @@ -267,17 +270,19 @@ def evaluation_loop(master, hooks.append( evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict)) + saver = None + if variables_to_restore is not None: + saver = tf_saver.Saver(variables_to_restore) + return evaluation.evaluate_repeatedly( checkpoint_dir, master=master, scaffold=monitored_session.Scaffold( - init_op=initial_op, - init_feed_dict=initial_op_feed_dict), + init_op=initial_op, init_feed_dict=initial_op_feed_dict, saver=saver), eval_ops=eval_op, feed_dict=eval_op_feed_dict, final_ops=final_op, final_ops_feed_dict=final_op_feed_dict, - variables_to_restore=variables_to_restore, eval_interval_secs=eval_interval_secs, hooks=hooks, config=session_config, diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index 4385eca770..3aaf4c3cd6 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -141,7 +141,6 @@ from __future__ import print_function import time from tensorflow.contrib.framework.python.ops import variables -from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import summary from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops @@ -368,7 +367,6 @@ def evaluate_once( feed_dict=None, final_ops=None, final_ops_feed_dict=None, - variables_to_restore=None, hooks=None, config=None): """Evaluates the model at the given checkpoint path. @@ -407,9 +405,6 @@ def evaluate_once( final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to `Tensors`. final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. - variables_to_restore: A list of TensorFlow variables to restore during - evaluation. If the argument is left as `None` then - tf.contrib.framework.get_variables_to_restore() is used. hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the evaluation loop. config: An instance of `tf.ConfigProto` that will be used to @@ -430,24 +425,13 @@ def evaluate_once( else: eval_ops = [eval_ops, update_eval_step] - # Must come before the scaffold check. - if scaffold and scaffold.saver: - saver = scaffold.saver - else: - saver = tf_saver.Saver( - variables_to_restore or variables.get_variables_to_restore(), - write_version=saver_pb2.SaverDef.V2) - - scaffold = scaffold or monitored_session.Scaffold() - scaffold = _scaffold_with_init(scaffold, saver, checkpoint_path) - logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime())) # Prepare the session creator. session_creator = monitored_session.ChiefSessionCreator( scaffold=scaffold, - checkpoint_dir=None, + checkpoint_filename_with_path=checkpoint_path, master=master, config=config) @@ -476,7 +460,6 @@ def evaluate_repeatedly( feed_dict=None, final_ops=None, final_ops_feed_dict=None, - variables_to_restore=None, eval_interval_secs=60, hooks=None, config=None, @@ -518,9 +501,6 @@ def evaluate_repeatedly( final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to `Tensors`. final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. - variables_to_restore: A list of TensorFlow variables to restore during - evaluation. If the argument is left as `None` then - tf.contrib.framework.get_variables_to_restore() is used. eval_interval_secs: The minimum number of seconds between evaluations. hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the evaluation loop. @@ -546,15 +526,6 @@ def evaluate_repeatedly( else: eval_ops = [eval_ops, update_eval_step] - # Must come before the scaffold check. - if scaffold and scaffold.saver: - saver = scaffold.saver - else: - saver = tf_saver.Saver( - variables_to_restore or variables.get_variables_to_restore()) - - scaffold = scaffold or monitored_session.Scaffold() - # Prepare the run hooks. hooks = hooks or [] @@ -566,8 +537,8 @@ def evaluate_repeatedly( checkpoint_dir, eval_interval_secs, timeout): session_creator = monitored_session.ChiefSessionCreator( - scaffold=_scaffold_with_init(scaffold, saver, checkpoint_path), - checkpoint_dir=None, + scaffold=scaffold, + checkpoint_filename_with_path=checkpoint_path, master=master, config=config) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 7c273c7c09..ea763a1540 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -327,8 +327,12 @@ class SessionCreator(object): class ChiefSessionCreator(SessionCreator): """Creates a tf.Session for a chief.""" - def __init__(self, scaffold=None, master='', config=None, - checkpoint_dir=None): + def __init__(self, + scaffold=None, + master='', + config=None, + checkpoint_dir=None, + checkpoint_filename_with_path=None): """Initializes a chief session creator. Args: @@ -338,8 +342,10 @@ class ChiefSessionCreator(SessionCreator): config: `ConfigProto` proto used to configure the session. checkpoint_dir: A string. Optional path to a directory where to restore variables. + checkpoint_filename_with_path: Full file name path to the checkpoint file. """ self._checkpoint_dir = checkpoint_dir + self._checkpoint_filename_with_path = checkpoint_filename_with_path self._scaffold = scaffold or Scaffold() self._session_manager = None self._master = master @@ -362,6 +368,7 @@ class ChiefSessionCreator(SessionCreator): self._master, saver=self._scaffold.saver, checkpoint_dir=self._checkpoint_dir, + checkpoint_filename_with_path=self._checkpoint_filename_with_path, config=self._config, init_op=self._scaffold.init_op, init_feed_dict=self._scaffold.init_feed_dict, diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index dafe6de363..4c810b95bb 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -829,6 +829,13 @@ class MonitoredSessionTest(test.TestCase): session_creator=monitored_session.ChiefSessionCreator( scaffold, checkpoint_dir=logdir)) as session: self.assertEqual(2, session.run(gstep)) + # A restart will find the checkpoint and recover automatically. + with monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + scaffold, + checkpoint_filename_with_path=saver_lib.latest_checkpoint( + logdir))) as session: + self.assertEqual(2, session.run(gstep)) def test_retry_on_aborted_error(self): # Tests that we silently retry on abort. Note that this does not test -- cgit v1.2.3