aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training/python/training/evaluation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/training/python/training/evaluation.py')
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py25
1 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index e663a1157c..6e5ae1d1f2 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -260,9 +260,10 @@ class StopAfterNEvalsHook(session_run_hook.SessionRunHook):
"""
# The number of evals to run for.
self._num_evals = num_evals
+ self._evals_completed = None
- def begin(self):
- self._evals_completed = get_or_create_eval_step()
+ def _set_evals_completed_tensor(self, updated_eval_step):
+ self._evals_completed = updated_eval_step
def before_run(self, run_context):
return session_run_hook.SessionRunArgs({
@@ -388,9 +389,16 @@ def evaluate_once(checkpoint_path,
"""
eval_step = get_or_create_eval_step()
+ # Prepare the run hooks.
+ hooks = hooks or []
+
if eval_ops is not None:
update_eval_step = state_ops.assign_add(eval_step, 1)
+ for h in hooks:
+ if isinstance(h, StopAfterNEvalsHook):
+ h._set_evals_completed_tensor(update_eval_step) # pylint: disable=protected-access
+
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
elif isinstance(eval_ops, (tuple, list)):
@@ -408,9 +416,6 @@ def evaluate_once(checkpoint_path,
master=master,
config=config)
- # Prepare the run hooks.
- hooks = hooks or []
-
final_ops_hook = basic_session_run_hooks.FinalOpsHook(
final_ops, final_ops_feed_dict)
hooks.append(final_ops_hook)
@@ -489,9 +494,16 @@ def evaluate_repeatedly(checkpoint_dir,
"""
eval_step = get_or_create_eval_step()
+ # Prepare the run hooks.
+ hooks = hooks or []
+
if eval_ops is not None:
update_eval_step = state_ops.assign_add(eval_step, 1)
+ for h in hooks:
+ if isinstance(h, StopAfterNEvalsHook):
+ h._set_evals_completed_tensor(update_eval_step) # pylint: disable=protected-access
+
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
elif isinstance(eval_ops, (tuple, list)):
@@ -499,9 +511,6 @@ def evaluate_repeatedly(checkpoint_dir,
else:
eval_ops = [eval_ops, update_eval_step]
- # Prepare the run hooks.
- hooks = hooks or []
-
final_ops_hook = basic_session_run_hooks.FinalOpsHook(
final_ops, final_ops_feed_dict)
hooks.append(final_ops_hook)