diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2018-09-26 19:16:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 19:20:27 -0700 |
commit | 5b971c7eae5f2049a4725b16a4a44b688d3506b0 (patch) | |
tree | 3dba274f008cca389a5d54df1eada9d8efd1670a /tensorflow/python/training | |
parent | 51a6118e5bd85935b1d9ec0e68b92f1f98d14982 (diff) |
Fix the eval hook to run the correct number of steps when using TPU strategy
PiperOrigin-RevId: 214709465
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/basic_session_run_hooks.py | 5 | ||||
-rw-r--r-- | tensorflow/python/training/evaluation.py | 68 |
2 files changed, 68 insertions, 5 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 3bd4bd75bd..1efabcd854 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): raise ValueError("steps_per_run should be greater than 0") self._num_steps = num_steps self._last_step = last_step - self._steps_per_run = steps_per_run + self._steps_per_run_initial_value = steps_per_run def begin(self): self._global_step_tensor = training_util.get_global_step() @@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): self._steps_per_run_variable = get_or_create_steps_per_run_variable() def _update_steps_per_run_variable(self, global_step, session): - steps = min(self._last_step - global_step, self._steps_per_run) + steps = min(self._last_step - global_step, + self._steps_per_run_initial_value) self._steps_per_run_variable.load(steps, session=session) def after_create_session(self, session, coord): diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py index b36444a14c..2c4eb02d53 100644 --- a/tensorflow/python/training/evaluation.py +++ b/tensorflow/python/training/evaluation.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time import math +import time from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops): return array_ops.identity(_get_or_create_eval_step().read_value()) +class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook): + """Run hook used by the evaluation routines to run the `eval_ops` N times.""" + + def __init__(self, num_evals, steps_per_run=1): + """Constructs the run hook. + + Args: + num_evals: The number of evaluations to run for. if set to None, will + iterate the dataset until all inputs are exhausted. + steps_per_run: Number of steps executed per run call. + """ + self._num_evals = num_evals + self._evals_completed = None + self._steps_per_run_initial_value = steps_per_run + + def _set_evals_completed_tensor(self, updated_eval_step): + self._evals_completed = updated_eval_step + + def begin(self): + self._steps_per_run_variable = \ + basic_session_run_hooks.get_or_create_steps_per_run_variable() + + def after_create_session(self, session, coord): + # Update number of steps to run in the first run call + if self._num_evals is None: + steps = self._steps_per_run_initial_value + else: + steps = min(self._steps_per_run_initial_value, self._num_evals) + self._steps_per_run_variable.load(steps, session=session) + + def before_run(self, run_context): + return session_run_hook.SessionRunArgs({ + 'evals_completed': self._evals_completed + }) + + def after_run(self, run_context, run_values): + evals_completed = run_values.results['evals_completed'] + # Update number of steps to run in the next iteration + if self._num_evals is None: + steps = self._steps_per_run_initial_value + else: + steps = min(self._num_evals - evals_completed, + self._steps_per_run_initial_value) + self._steps_per_run_variable.load(steps, session=run_context.session) + + if self._num_evals is None: + logging.info('Evaluation [%d]', evals_completed) + else: + logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) + if self._num_evals is not None and evals_completed >= self._num_evals: + run_context.request_stop() + + class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): """Run hook used by the evaluation routines to run the `eval_ops` N times.""" @@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path, hooks = list(hooks or []) if eval_ops is not None: - update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) + if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]): + steps_per_run_variable = \ + basic_session_run_hooks.get_or_create_steps_per_run_variable() + update_eval_step = state_ops.assign_add( + eval_step, + math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype), + use_locking=True) + else: + update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) if isinstance(eval_ops, dict): eval_ops['update_eval_step'] = update_eval_step @@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path, eval_step_value = _get_latest_eval_step_value(eval_ops) for h in hooks: - if isinstance(h, _StopAfterNEvalsHook): + if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)): h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', |