diff options
Diffstat (limited to 'tensorflow/contrib/training/python/training/evaluation.py')
-rw-r--r-- | tensorflow/contrib/training/python/training/evaluation.py | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index e663a1157c..6e5ae1d1f2 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -260,9 +260,10 @@ class StopAfterNEvalsHook(session_run_hook.SessionRunHook): """ # The number of evals to run for. self._num_evals = num_evals + self._evals_completed = None - def begin(self): - self._evals_completed = get_or_create_eval_step() + def _set_evals_completed_tensor(self, updated_eval_step): + self._evals_completed = updated_eval_step def before_run(self, run_context): return session_run_hook.SessionRunArgs({ @@ -388,9 +389,16 @@ def evaluate_once(checkpoint_path, """ eval_step = get_or_create_eval_step() + # Prepare the run hooks. + hooks = hooks or [] + if eval_ops is not None: update_eval_step = state_ops.assign_add(eval_step, 1) + for h in hooks: + if isinstance(h, StopAfterNEvalsHook): + h._set_evals_completed_tensor(update_eval_step) # pylint: disable=protected-access + if isinstance(eval_ops, dict): eval_ops['update_eval_step'] = update_eval_step elif isinstance(eval_ops, (tuple, list)): @@ -408,9 +416,6 @@ def evaluate_once(checkpoint_path, master=master, config=config) - # Prepare the run hooks. - hooks = hooks or [] - final_ops_hook = basic_session_run_hooks.FinalOpsHook( final_ops, final_ops_feed_dict) hooks.append(final_ops_hook) @@ -489,9 +494,16 @@ def evaluate_repeatedly(checkpoint_dir, """ eval_step = get_or_create_eval_step() + # Prepare the run hooks. + hooks = hooks or [] + if eval_ops is not None: update_eval_step = state_ops.assign_add(eval_step, 1) + for h in hooks: + if isinstance(h, StopAfterNEvalsHook): + h._set_evals_completed_tensor(update_eval_step) # pylint: disable=protected-access + if isinstance(eval_ops, dict): eval_ops['update_eval_step'] = update_eval_step elif isinstance(eval_ops, (tuple, list)): @@ -499,9 +511,6 @@ def evaluate_repeatedly(checkpoint_dir, else: eval_ops = [eval_ops, update_eval_step] - # Prepare the run hooks. - hooks = hooks or [] - final_ops_hook = basic_session_run_hooks.FinalOpsHook( final_ops, final_ops_feed_dict) hooks.append(final_ops_hook) |