diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2018-09-26 19:16:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 19:20:27 -0700 |
commit | 5b971c7eae5f2049a4725b16a4a44b688d3506b0 (patch) | |
tree | 3dba274f008cca389a5d54df1eada9d8efd1670a /tensorflow | |
parent | 51a6118e5bd85935b1d9ec0e68b92f1f98d14982 (diff) |
Fix the eval hook to run the correct number of steps when using TPU strategy
PiperOrigin-RevId: 214709465
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 23 | ||||
-rw-r--r-- | tensorflow/python/training/basic_session_run_hooks.py | 5 | ||||
-rw-r--r-- | tensorflow/python/training/evaluation.py | 68 |
3 files changed, 90 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index eec64ad452..827b405e51 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -474,11 +474,31 @@ class Estimator(object): return _evaluate() def _convert_eval_steps_to_hooks(self, steps): + """Create hooks to run correct number of steps in evaluation. + + Args: + steps: number of steps to run during evaluation. + + Raises: + ValueError: if steps is less than or equal to zero. + + Returns: + List of hooks to be passed to the estimator. + """ if steps is None: return [] if steps <= 0: raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + # The hooks are declared as private in evaluation.py discourage the use + # by other libraries or open source users. This should be the only usage + # of the estimator evaluation hooks. + if self._eval_distribution: + steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1) + if steps_per_run > 1: + return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access + num_evals=steps, steps_per_run=steps_per_run)] return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access def predict(self, @@ -1474,6 +1494,7 @@ class Estimator(object): self._eval_distribution.__class__.__name__ == 'TPUStrategy') if is_tpu_strategy: + steps_per_run_variable = training.get_or_create_steps_per_run_variable() def step_fn(ctx, features, labels=None): """Runs one step of the eval computation and captures outputs.""" estimator_spec = self._eval_distribution.call_for_each_tower( @@ -1490,7 +1511,7 @@ class Estimator(object): # TODO(priyag): Fix eval step hook to account for steps_per_run. ctx = self._eval_distribution.run_steps_on_dataset( - step_fn, iterator, iterations=self._eval_distribution.steps_per_run) + step_fn, iterator, iterations=steps_per_run_variable) update_op = ctx.run_op eval_dict = ctx.non_tensor_outputs['eval_dict'] grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec'] diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 3bd4bd75bd..1efabcd854 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): raise ValueError("steps_per_run should be greater than 0") self._num_steps = num_steps self._last_step = last_step - self._steps_per_run = steps_per_run + self._steps_per_run_initial_value = steps_per_run def begin(self): self._global_step_tensor = training_util.get_global_step() @@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): self._steps_per_run_variable = get_or_create_steps_per_run_variable() def _update_steps_per_run_variable(self, global_step, session): - steps = min(self._last_step - global_step, self._steps_per_run) + steps = min(self._last_step - global_step, + self._steps_per_run_initial_value) self._steps_per_run_variable.load(steps, session=session) def after_create_session(self, session, coord): diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py index b36444a14c..2c4eb02d53 100644 --- a/tensorflow/python/training/evaluation.py +++ b/tensorflow/python/training/evaluation.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time import math +import time from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops): return array_ops.identity(_get_or_create_eval_step().read_value()) +class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook): + """Run hook used by the evaluation routines to run the `eval_ops` N times.""" + + def __init__(self, num_evals, steps_per_run=1): + """Constructs the run hook. + + Args: + num_evals: The number of evaluations to run for. if set to None, will + iterate the dataset until all inputs are exhausted. + steps_per_run: Number of steps executed per run call. + """ + self._num_evals = num_evals + self._evals_completed = None + self._steps_per_run_initial_value = steps_per_run + + def _set_evals_completed_tensor(self, updated_eval_step): + self._evals_completed = updated_eval_step + + def begin(self): + self._steps_per_run_variable = \ + basic_session_run_hooks.get_or_create_steps_per_run_variable() + + def after_create_session(self, session, coord): + # Update number of steps to run in the first run call + if self._num_evals is None: + steps = self._steps_per_run_initial_value + else: + steps = min(self._steps_per_run_initial_value, self._num_evals) + self._steps_per_run_variable.load(steps, session=session) + + def before_run(self, run_context): + return session_run_hook.SessionRunArgs({ + 'evals_completed': self._evals_completed + }) + + def after_run(self, run_context, run_values): + evals_completed = run_values.results['evals_completed'] + # Update number of steps to run in the next iteration + if self._num_evals is None: + steps = self._steps_per_run_initial_value + else: + steps = min(self._num_evals - evals_completed, + self._steps_per_run_initial_value) + self._steps_per_run_variable.load(steps, session=run_context.session) + + if self._num_evals is None: + logging.info('Evaluation [%d]', evals_completed) + else: + logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) + if self._num_evals is not None and evals_completed >= self._num_evals: + run_context.request_stop() + + class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): """Run hook used by the evaluation routines to run the `eval_ops` N times.""" @@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path, hooks = list(hooks or []) if eval_ops is not None: - update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) + if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]): + steps_per_run_variable = \ + basic_session_run_hooks.get_or_create_steps_per_run_variable() + update_eval_step = state_ops.assign_add( + eval_step, + math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype), + use_locking=True) + else: + update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) if isinstance(eval_ops, dict): eval_ops['update_eval_step'] = update_eval_step @@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path, eval_step_value = _get_latest_eval_step_value(eval_ops) for h in hooks: - if isinstance(h, _StopAfterNEvalsHook): + if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)): h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S', |