diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2018-09-26 19:16:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 19:20:27 -0700 |
commit | 5b971c7eae5f2049a4725b16a4a44b688d3506b0 (patch) | |
tree | 3dba274f008cca389a5d54df1eada9d8efd1670a /tensorflow/python/estimator | |
parent | 51a6118e5bd85935b1d9ec0e68b92f1f98d14982 (diff) |
Fix the eval hook to run the correct number of steps when using TPU strategy
PiperOrigin-RevId: 214709465
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index eec64ad452..827b405e51 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -474,11 +474,31 @@ class Estimator(object): return _evaluate() def _convert_eval_steps_to_hooks(self, steps): + """Create hooks to run correct number of steps in evaluation. + + Args: + steps: number of steps to run during evaluation. + + Raises: + ValueError: if steps is less than or equal to zero. + + Returns: + List of hooks to be passed to the estimator. + """ if steps is None: return [] if steps <= 0: raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + # The hooks are declared as private in evaluation.py discourage the use + # by other libraries or open source users. This should be the only usage + # of the estimator evaluation hooks. + if self._eval_distribution: + steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1) + if steps_per_run > 1: + return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access + num_evals=steps, steps_per_run=steps_per_run)] return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access def predict(self, @@ -1474,6 +1494,7 @@ class Estimator(object): self._eval_distribution.__class__.__name__ == 'TPUStrategy') if is_tpu_strategy: + steps_per_run_variable = training.get_or_create_steps_per_run_variable() def step_fn(ctx, features, labels=None): """Runs one step of the eval computation and captures outputs.""" estimator_spec = self._eval_distribution.call_for_each_tower( @@ -1490,7 +1511,7 @@ class Estimator(object): # TODO(priyag): Fix eval step hook to account for steps_per_run. ctx = self._eval_distribution.run_steps_on_dataset( - step_fn, iterator, iterations=self._eval_distribution.steps_per_run) + step_fn, iterator, iterations=steps_per_run_variable) update_op = ctx.run_op eval_dict = ctx.non_tensor_outputs['eval_dict'] grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec'] |