author    Mustafa Ispir <ispir@google.com>    2016-12-21 09:03:36 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-12-21 09:25:27 -0800
commit 82594d38f66aa86c37ea1deb8d0631efb7d9ba96 (patch)
tree   1362cafc82e8cec5a46166dfd9b86926ecc7c196
parent 6ecaae486ecb12ad693a85e1bc3aa80b5afab418 (diff)
Bug fix related to a corner case in evaluate_once. Corner case:
* The user defines an init_op that should not run after restoring. The current implementation runs init_op anyway, because session_manager thinks nothing has been restored.
Removed the variables_to_restore argument, since its usage conflicts with Scaffold.saver: if the user defines both a saver and variables_to_restore, the current implementation silently ignores variables_to_restore. Callers can instead build a Saver from variables_to_restore and pass it through Scaffold, as in the sketch after the file list below. Used the recently added feature of session_manager that restores from a specific checkpoint file.
Change: 142669254
-rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation.py           | 17
-rw-r--r--  tensorflow/contrib/training/python/training/evaluation.py   | 35
-rw-r--r--  tensorflow/python/training/monitored_session.py             | 11
-rw-r--r--  tensorflow/python/training/monitored_session_test.py        |  7
4 files changed, 30 insertions, 40 deletions
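
For readers updating call sites: below is a minimal sketch (not part of this commit) of how a caller that previously passed variables_to_restore to evaluate_once can route the same variable list through Scaffold.saver instead. The checkpoint path is a hypothetical placeholder and eval_ops is elided for brevity; the import paths match the ones used in this patch.

  from tensorflow.contrib.framework.python.ops import variables as variables_lib
  from tensorflow.contrib.training.python.training import evaluation
  from tensorflow.python.training import monitored_session
  from tensorflow.python.training import saver as tf_saver

  # Hypothetical checkpoint file; in practice this points at a real checkpoint.
  checkpoint_path = '/tmp/model/model.ckpt-1000'

  # Build an explicit Saver over the variables to restore and attach it to the
  # Scaffold; evaluate_once no longer accepts variables_to_restore directly.
  saver = tf_saver.Saver(variables_lib.get_variables_to_restore())
  scaffold = monitored_session.Scaffold(saver=saver)

  evaluation.evaluate_once(
      checkpoint_path,
      scaffold=scaffold,
      eval_ops=None)  # Pass the real metric update ops here.
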
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index b89eca46ea..231b3af502 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -125,6 +125,7 @@ from __future__ import print_function
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.python import summary
from tensorflow.python.training import monitored_session
+from tensorflow.python.training import saver as tf_saver
__all__ = [
'evaluate_once',
@@ -192,17 +193,19 @@ def evaluate_once(master,
hooks.append(
evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict))
+ saver = None
+ if variables_to_restore is not None:
+ saver = tf_saver.Saver(variables_to_restore)
+
return evaluation.evaluate_once(
checkpoint_path,
master=master,
scaffold=monitored_session.Scaffold(
- init_op=initial_op,
- init_feed_dict=initial_op_feed_dict),
+ init_op=initial_op, init_feed_dict=initial_op_feed_dict, saver=saver),
eval_ops=eval_op,
feed_dict=eval_op_feed_dict,
final_ops=final_op,
final_ops_feed_dict=final_op_feed_dict,
- variables_to_restore=variables_to_restore,
hooks=hooks,
config=session_config)
@@ -267,17 +270,19 @@ def evaluation_loop(master,
hooks.append(
evaluation.SummaryAtEndHook(logdir, summary_op, summary_op_feed_dict))
+ saver = None
+ if variables_to_restore is not None:
+ saver = tf_saver.Saver(variables_to_restore)
+
return evaluation.evaluate_repeatedly(
checkpoint_dir,
master=master,
scaffold=monitored_session.Scaffold(
- init_op=initial_op,
- init_feed_dict=initial_op_feed_dict),
+ init_op=initial_op, init_feed_dict=initial_op_feed_dict, saver=saver),
eval_ops=eval_op,
feed_dict=eval_op_feed_dict,
final_ops=final_op,
final_ops_feed_dict=final_op_feed_dict,
- variables_to_restore=variables_to_restore,
eval_interval_secs=eval_interval_secs,
hooks=hooks,
config=session_config,
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index 4385eca770..3aaf4c3cd6 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -141,7 +141,6 @@ from __future__ import print_function
import time
from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
@@ -368,7 +367,6 @@ def evaluate_once(
feed_dict=None,
final_ops=None,
final_ops_feed_dict=None,
- variables_to_restore=None,
hooks=None,
config=None):
"""Evaluates the model at the given checkpoint path.
@@ -407,9 +405,6 @@ def evaluate_once(
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
- variables_to_restore: A list of TensorFlow variables to restore during
- evaluation. If the argument is left as `None` then
- tf.contrib.framework.get_variables_to_restore() is used.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
evaluation loop.
config: An instance of `tf.ConfigProto` that will be used to
@@ -430,24 +425,13 @@ def evaluate_once(
else:
eval_ops = [eval_ops, update_eval_step]
- # Must come before the scaffold check.
- if scaffold and scaffold.saver:
- saver = scaffold.saver
- else:
- saver = tf_saver.Saver(
- variables_to_restore or variables.get_variables_to_restore(),
- write_version=saver_pb2.SaverDef.V2)
-
- scaffold = scaffold or monitored_session.Scaffold()
- scaffold = _scaffold_with_init(scaffold, saver, checkpoint_path)
-
logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
time.gmtime()))
# Prepare the session creator.
session_creator = monitored_session.ChiefSessionCreator(
scaffold=scaffold,
- checkpoint_dir=None,
+ checkpoint_filename_with_path=checkpoint_path,
master=master,
config=config)
@@ -476,7 +460,6 @@ def evaluate_repeatedly(
feed_dict=None,
final_ops=None,
final_ops_feed_dict=None,
- variables_to_restore=None,
eval_interval_secs=60,
hooks=None,
config=None,
@@ -518,9 +501,6 @@ def evaluate_repeatedly(
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
- variables_to_restore: A list of TensorFlow variables to restore during
- evaluation. If the argument is left as `None` then
- tf.contrib.framework.get_variables_to_restore() is used.
eval_interval_secs: The minimum number of seconds between evaluations.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
evaluation loop.
@@ -546,15 +526,6 @@ def evaluate_repeatedly(
else:
eval_ops = [eval_ops, update_eval_step]
- # Must come before the scaffold check.
- if scaffold and scaffold.saver:
- saver = scaffold.saver
- else:
- saver = tf_saver.Saver(
- variables_to_restore or variables.get_variables_to_restore())
-
- scaffold = scaffold or monitored_session.Scaffold()
-
# Prepare the run hooks.
hooks = hooks or []
@@ -566,8 +537,8 @@ def evaluate_repeatedly(
checkpoint_dir, eval_interval_secs, timeout):
session_creator = monitored_session.ChiefSessionCreator(
- scaffold=_scaffold_with_init(scaffold, saver, checkpoint_path),
- checkpoint_dir=None,
+ scaffold=scaffold,
+ checkpoint_filename_with_path=checkpoint_path,
master=master,
config=config)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7c273c7c09..ea763a1540 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -327,8 +327,12 @@ class SessionCreator(object):
class ChiefSessionCreator(SessionCreator):
"""Creates a tf.Session for a chief."""
- def __init__(self, scaffold=None, master='', config=None,
- checkpoint_dir=None):
+ def __init__(self,
+ scaffold=None,
+ master='',
+ config=None,
+ checkpoint_dir=None,
+ checkpoint_filename_with_path=None):
"""Initializes a chief session creator.
Args:
@@ -338,8 +342,10 @@ class ChiefSessionCreator(SessionCreator):
config: `ConfigProto` proto used to configure the session.
checkpoint_dir: A string. Optional path to a directory where to restore
variables.
+ checkpoint_filename_with_path: Full file name path to the checkpoint file.
"""
self._checkpoint_dir = checkpoint_dir
+ self._checkpoint_filename_with_path = checkpoint_filename_with_path
self._scaffold = scaffold or Scaffold()
self._session_manager = None
self._master = master
@@ -362,6 +368,7 @@ class ChiefSessionCreator(SessionCreator):
self._master,
saver=self._scaffold.saver,
checkpoint_dir=self._checkpoint_dir,
+ checkpoint_filename_with_path=self._checkpoint_filename_with_path,
config=self._config,
init_op=self._scaffold.init_op,
init_feed_dict=self._scaffold.init_feed_dict,
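
A minimal usage sketch (not part of the patch) of the new checkpoint_filename_with_path argument, assuming a checkpoint already exists under a hypothetical logdir; it mirrors the test added below.

  from tensorflow.python.training import monitored_session
  from tensorflow.python.training import saver as saver_lib

  # Hypothetical directory that already contains a checkpoint.
  logdir = '/tmp/model'

  # Restore from one specific checkpoint file rather than scanning a directory.
  session_creator = monitored_session.ChiefSessionCreator(
      checkpoint_filename_with_path=saver_lib.latest_checkpoint(logdir))
  with monitored_session.MonitoredSession(
      session_creator=session_creator) as session:
    pass  # session.run(eval_ops) would go here.
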
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index dafe6de363..4c810b95bb 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -829,6 +829,13 @@ class MonitoredSessionTest(test.TestCase):
session_creator=monitored_session.ChiefSessionCreator(
scaffold, checkpoint_dir=logdir)) as session:
self.assertEqual(2, session.run(gstep))
+ # A restart will find the checkpoint and recover automatically.
+ with monitored_session.MonitoredSession(
+ session_creator=monitored_session.ChiefSessionCreator(
+ scaffold,
+ checkpoint_filename_with_path=saver_lib.latest_checkpoint(
+ logdir))) as session:
+ self.assertEqual(2, session.run(gstep))
def test_retry_on_aborted_error(self):
# Tests that we silently retry on abort. Note that this does not test