aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2018-09-26 19:16:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 19:20:27 -0700
commit5b971c7eae5f2049a4725b16a4a44b688d3506b0 (patch)
tree3dba274f008cca389a5d54df1eada9d8efd1670a /tensorflow/python/training
parent51a6118e5bd85935b1d9ec0e68b92f1f98d14982 (diff)
Fix the eval hook to run the correct number of steps when using TPU strategy
PiperOrigin-RevId: 214709465
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
-rw-r--r--tensorflow/python/training/evaluation.py68
2 files changed, 68 insertions, 5 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3bd4bd75bd..1efabcd854 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
raise ValueError("steps_per_run should be greater than 0")
self._num_steps = num_steps
self._last_step = last_step
- self._steps_per_run = steps_per_run
+ self._steps_per_run_initial_value = steps_per_run
def begin(self):
self._global_step_tensor = training_util.get_global_step()
@@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
self._steps_per_run_variable = get_or_create_steps_per_run_variable()
def _update_steps_per_run_variable(self, global_step, session):
- steps = min(self._last_step - global_step, self._steps_per_run)
+ steps = min(self._last_step - global_step,
+ self._steps_per_run_initial_value)
self._steps_per_run_variable.load(steps, session=session)
def after_create_session(self, session, coord):
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index b36444a14c..2c4eb02d53 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -18,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
import math
+import time
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops):
return array_ops.identity(_get_or_create_eval_step().read_value())
+class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
+ """Run hook used by the evaluation routines to run the `eval_ops` N times."""
+
+ def __init__(self, num_evals, steps_per_run=1):
+ """Constructs the run hook.
+
+ Args:
+ num_evals: The number of evaluations to run for. if set to None, will
+ iterate the dataset until all inputs are exhausted.
+ steps_per_run: Number of steps executed per run call.
+ """
+ self._num_evals = num_evals
+ self._evals_completed = None
+ self._steps_per_run_initial_value = steps_per_run
+
+ def _set_evals_completed_tensor(self, updated_eval_step):
+ self._evals_completed = updated_eval_step
+
+ def begin(self):
+ self._steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+
+ def after_create_session(self, session, coord):
+ # Update number of steps to run in the first run call
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._steps_per_run_initial_value, self._num_evals)
+ self._steps_per_run_variable.load(steps, session=session)
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs({
+ 'evals_completed': self._evals_completed
+ })
+
+ def after_run(self, run_context, run_values):
+ evals_completed = run_values.results['evals_completed']
+ # Update number of steps to run in the next iteration
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._num_evals - evals_completed,
+ self._steps_per_run_initial_value)
+ self._steps_per_run_variable.load(steps, session=run_context.session)
+
+ if self._num_evals is None:
+ logging.info('Evaluation [%d]', evals_completed)
+ else:
+ logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
+ if self._num_evals is not None and evals_completed >= self._num_evals:
+ run_context.request_stop()
+
+
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
"""Run hook used by the evaluation routines to run the `eval_ops` N times."""
@@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path,
hooks = list(hooks or [])
if eval_ops is not None:
- update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
+ if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]):
+ steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+ update_eval_step = state_ops.assign_add(
+ eval_step,
+ math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
+ use_locking=True)
+ else:
+ update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
@@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path,
eval_step_value = _get_latest_eval_step_value(eval_ops)
for h in hooks:
- if isinstance(h, _StopAfterNEvalsHook):
+ if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access
logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',