aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yanan Cao <ycao@google.com>2018-07-11 12:06:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 12:09:13 -0700
commitae1056ea22c8462668e168741fae1b456c9155d9 (patch)
treee1355e51d447f9a195340d9f3077a80c1abc2944
parent2e9e45201adcd65634ee88f58309002c4fdd95e6 (diff)
Always append ExamplesPerSecond hook in TPUEstimator
PiperOrigin-RevId: 204164915
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 3ab2a00ba2..8a137005b6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -2298,10 +2298,20 @@ class TPUEstimator(estimator_lib.Estimator):
# Clear the bit.
self._is_input_fn_invoked = None
+ # examples_hook is added to training_hooks for both CPU and TPU
+ # execution.
+ examples_hook = ExamplesPerSecondHook(
+ ctx.global_batch_size,
+ output_dir=self.model_dir,
+ every_n_steps=self._log_every_n_steps)
+
if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
logging.info('Running %s on CPU', mode)
- return model_fn_wrapper.call_without_tpu(
+ estimator_spec = model_fn_wrapper.call_without_tpu(
features, labels, is_export_mode=is_export_mode)
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks + (examples_hook,))
+ return estimator_spec
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
# TPUEstimator._call_input_fn passes `input_fn` as features to here.
@@ -2369,10 +2379,6 @@ class TPUEstimator(estimator_lib.Estimator):
},
every_n_iter=logging_hook_frequency)
])
- examples_hook = ExamplesPerSecondHook(
- ctx.global_batch_size,
- output_dir=self.model_dir,
- every_n_steps=self._log_every_n_steps)
examples_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)