diff options
author | 2018-07-11 12:06:01 -0700 | |
---|---|---|
committer | 2018-07-11 12:09:13 -0700 | |
commit | ae1056ea22c8462668e168741fae1b456c9155d9 (patch) | |
tree | e1355e51d447f9a195340d9f3077a80c1abc2944 | |
parent | 2e9e45201adcd65634ee88f58309002c4fdd95e6 (diff) |
Always append ExamplesPerSecond hook in TPUEstimator
PiperOrigin-RevId: 204164915
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 3ab2a00ba2..8a137005b6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2298,10 +2298,20 @@ class TPUEstimator(estimator_lib.Estimator): # Clear the bit. self._is_input_fn_invoked = None + # examples_hook is added to training_hooks for both CPU and TPU + # execution. + examples_hook = ExamplesPerSecondHook( + ctx.global_batch_size, + output_dir=self.model_dir, + every_n_steps=self._log_every_n_steps) + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu( + estimator_spec = model_fn_wrapper.call_without_tpu( features, labels, is_export_mode=is_export_mode) + estimator_spec = estimator_spec._replace( + training_hooks=estimator_spec.training_hooks + (examples_hook,)) + return estimator_spec assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here. @@ -2369,10 +2379,6 @@ class TPUEstimator(estimator_lib.Estimator): }, every_n_iter=logging_hook_frequency) ]) - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) examples_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) |