aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2018-09-26 19:16:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 19:20:27 -0700
commit5b971c7eae5f2049a4725b16a4a44b688d3506b0 (patch)
tree3dba274f008cca389a5d54df1eada9d8efd1670a /tensorflow/python/estimator
parent51a6118e5bd85935b1d9ec0e68b92f1f98d14982 (diff)
Fix the eval hook to run the correct number of steps when using TPU strategy
PiperOrigin-RevId: 214709465
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/estimator.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eec64ad452..827b405e51 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -474,11 +474,31 @@ class Estimator(object):
return _evaluate()
def _convert_eval_steps_to_hooks(self, steps):
+ """Create hooks to run correct number of steps in evaluation.
+
+ Args:
+ steps: number of steps to run during evaluation.
+
+ Raises:
+ ValueError: if steps is less than or equal to zero.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is None:
return []
if steps <= 0:
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
+
+ # The hooks are declared as private in evaluation.py discourage the use
+ # by other libraries or open source users. This should be the only usage
+ # of the estimator evaluation hooks.
+ if self._eval_distribution:
+ steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access
+ num_evals=steps, steps_per_run=steps_per_run)]
return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access
def predict(self,
@@ -1474,6 +1494,7 @@ class Estimator(object):
self._eval_distribution.__class__.__name__ == 'TPUStrategy')
if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
def step_fn(ctx, features, labels=None):
"""Runs one step of the eval computation and captures outputs."""
estimator_spec = self._eval_distribution.call_for_each_tower(
@@ -1490,7 +1511,7 @@ class Estimator(object):
# TODO(priyag): Fix eval step hook to account for steps_per_run.
ctx = self._eval_distribution.run_steps_on_dataset(
- step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ step_fn, iterator, iterations=steps_per_run_variable)
update_op = ctx.run_op
eval_dict = ctx.non_tensor_outputs['eval_dict']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']