aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Youlong Cheng <ylc@google.com>2018-07-26 13:08:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 13:12:53 -0700
commitb5ee7b1bc75928825991957c189ded0b970a1081 (patch)
tree71eff08b3245b1945004a2d69e7fb10ec996746c
parent4fd159cbd0492c21de08197ba4426a2e433ff402 (diff)
PUBLIC: Allow user passing training/evaluation/prediction_hooks from tf.estimator.EstimatorSpec.
PiperOrigin-RevId: 206208119
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py98
-rw-r--r--tensorflow/python/estimator/model_fn.py58
2 files changed, 101 insertions, 55 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 6e7dae6fce..ee9ad525ee 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -258,7 +258,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metrics=None,
export_outputs=None,
scaffold_fn=None,
- host_call=None):
+ host_call=None,
+ training_hooks=None,
+ evaluation_hooks=None,
+ prediction_hooks=None):
"""Creates a validated `TPUEstimatorSpec` instance."""
host_calls = {}
if eval_metrics is not None:
@@ -266,6 +269,17 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
if host_call is not None:
host_calls['host_call'] = host_call
_OutfeedHostCall.validate(host_calls)
+
+ training_hooks = list(training_hooks or [])
+ evaluation_hooks = list(evaluation_hooks or [])
+ prediction_hooks = list(prediction_hooks or [])
+
+ for hook in training_hooks + evaluation_hooks + prediction_hooks:
+ if not isinstance(hook, session_run_hook.SessionRunHook):
+ raise TypeError(
+ 'All hooks must be SessionRunHook instances, given: {}'.format(
+ hook))
+
return super(TPUEstimatorSpec, cls).__new__(
cls,
mode=mode,
@@ -275,7 +289,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metrics=eval_metrics,
export_outputs=export_outputs,
scaffold_fn=scaffold_fn,
- host_call=host_call)
+ host_call=host_call,
+ training_hooks=training_hooks,
+ evaluation_hooks=evaluation_hooks,
+ prediction_hooks=prediction_hooks)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
@@ -291,6 +308,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
hooks = None
if self.host_call is not None:
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
+ hooks = list(hooks or [])
scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec(
mode=self.mode,
@@ -300,9 +318,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
eval_metric_ops=eval_metric_ops,
export_outputs=self.export_outputs,
scaffold=scaffold,
- training_hooks=hooks,
- evaluation_hooks=hooks,
- prediction_hooks=hooks)
+ training_hooks=self.training_hooks + hooks,
+ evaluation_hooks=self.evaluation_hooks + hooks,
+ prediction_hooks=self.prediction_hooks + hooks)
class _OpQueueContext(object):
@@ -1220,6 +1238,7 @@ class _ModelFnWrapper(object):
host_call = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_training_hooks = _CapturedObject()
def train_step(loss):
"""Training step function for use inside a while loop."""
@@ -1236,6 +1255,8 @@ class _ModelFnWrapper(object):
else:
captured_scaffold_fn.capture(None)
+ captured_training_hooks.capture(estimator_spec.training_hooks)
+
# We must run train_op to update the variables prior to running the
# outfeed.
with ops.control_dependencies([train_op]):
@@ -1247,7 +1268,8 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_call_outfeed_ops):
return array_ops.identity(loss)
- return train_step, host_call, captured_scaffold_fn
+ return (train_step, host_call, captured_scaffold_fn,
+ captured_training_hooks)
def convert_to_single_tpu_eval_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single eval step on TPU.
@@ -1277,6 +1299,7 @@ class _ModelFnWrapper(object):
"""
host_calls = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_eval_hooks = _CapturedObject()
def eval_step(total_loss):
"""Evaluation step function for use inside a while loop."""
@@ -1291,6 +1314,8 @@ class _ModelFnWrapper(object):
loss = tpu_estimator_spec.loss
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
+ captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)
+
to_record = {}
if tpu_estimator_spec.eval_metrics:
to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
@@ -1303,7 +1328,7 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_calls.create_enqueue_op()):
return math_ops.add(total_loss, loss)
- return eval_step, host_calls, captured_scaffold_fn
+ return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
def convert_to_single_tpu_predict_step(self, dequeue_fn):
"""Converts user provided model_fn` as a single predict step on TPU.
@@ -1318,6 +1343,7 @@ class _ModelFnWrapper(object):
"""
host_calls = _OutfeedHostCall(self._ctx)
captured_scaffold_fn = _CapturedObject()
+ captured_predict_hooks = _CapturedObject()
def predict_step(unused_scalar_stopping_signal):
"""Evaluation step function for use inside a while loop."""
@@ -1338,6 +1364,7 @@ class _ModelFnWrapper(object):
self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
+ captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
to_record = {}
identity_fn = lambda **kwargs: kwargs
to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
@@ -1349,7 +1376,8 @@ class _ModelFnWrapper(object):
with ops.control_dependencies(host_calls.create_enqueue_op()):
return _StopSignals.as_scalar_stopping_signal(stopping_signals)
- return predict_step, host_calls, captured_scaffold_fn
+ return (predict_step, host_calls, captured_scaffold_fn,
+ captured_predict_hooks)
def _verify_tpu_spec_predictions(self, predictions):
"""Validates TPUEstimatorSpec.predictions dict."""
@@ -1471,11 +1499,9 @@ class _ModelFnWrapper(object):
err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
if estimator_spec.training_chief_hooks:
- raise ValueError(err_msg.format('training_chief_hooks'))
- if estimator_spec.training_hooks:
- raise ValueError(err_msg.format('training_hooks'))
- if estimator_spec.evaluation_hooks:
- raise ValueError(err_msg.format('evaluation_hooks'))
+ raise ValueError(
+ err_msg.format('training_chief_hooks') + 'If you want' +
+ ' to pass training hooks, please pass via training_hooks.')
if estimator_spec.scaffold:
logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
@@ -1957,10 +1983,9 @@ class TPUEstimator(estimator_lib.Estimator):
"""Constructs an `TPUEstimator` instance.
Args:
- model_fn: Model function as required by `Estimator`. For training, the
- returned `EstimatorSpec` cannot have hooks as it is not supported in
- `TPUEstimator`. Instead, the user can pass the training hooks as
- an argument to `TPUEstimator.train()`.
+ model_fn: Model function as required by `Estimator` which returns
+ EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks',
+ and `prediction_hooks` must not capure any TPU Tensor inside the model_fn.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model. If `None`, the model_dir in
@@ -2429,7 +2454,7 @@ class TPUEstimator(estimator_lib.Estimator):
graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
if mode == model_fn_lib.ModeKeys.TRAIN:
- loss, host_call, scaffold = (
+ loss, host_call, scaffold, training_hooks = (
_train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
host_ops = host_call.create_tpu_hostcall()
if host_ops is None:
@@ -2484,6 +2509,9 @@ class TPUEstimator(estimator_lib.Estimator):
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)
+ if training_hooks:
+ hooks.extend(training_hooks)
+
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
@@ -2495,6 +2523,7 @@ class TPUEstimator(estimator_lib.Estimator):
checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
chief_hooks.append(checkpoint_hook)
+
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
update_ops = _sync_variables_ops()
@@ -2514,7 +2543,7 @@ class TPUEstimator(estimator_lib.Estimator):
scaffold=scaffold)
if mode == model_fn_lib.ModeKeys.EVAL:
- total_loss, host_calls, scaffold = _eval_on_tpu_system(
+ total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
iterations_per_loop_var = _create_or_get_iterations_per_loop()
mean_loss = math_ops.div(total_loss,
@@ -2558,6 +2587,9 @@ class TPUEstimator(estimator_lib.Estimator):
rendezvous=self._rendezvous[mode]),
] + input_hooks
+ if eval_hooks:
+ hooks.extend(eval_hooks)
+
return model_fn_lib.EstimatorSpec(
mode,
loss=mean_loss,
@@ -2568,8 +2600,9 @@ class TPUEstimator(estimator_lib.Estimator):
# Predict
assert mode == model_fn_lib.ModeKeys.PREDICT
- dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system(
- ctx, model_fn_wrapper, dequeue_fn)
+ (dummy_predict_op, host_calls,
+ scaffold, prediction_hooks) = _predict_on_tpu_system(
+ ctx, model_fn_wrapper, dequeue_fn)
with ops.control_dependencies([dummy_predict_op]):
internal_ops_to_run = _sync_variables_ops()
with ops.control_dependencies(internal_ops_to_run):
@@ -2625,6 +2658,9 @@ class TPUEstimator(estimator_lib.Estimator):
ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]),
] + input_hooks
+ if prediction_hooks:
+ hooks.extend(prediction_hooks)
+
return model_fn_lib.EstimatorSpec(
mode,
prediction_hooks=hooks,
@@ -2708,8 +2744,8 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_eval_step, host_calls, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
+ (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
+ ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)
def multi_tpu_eval_steps_on_single_shard():
return training_loop.repeat(
@@ -2724,15 +2760,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
- return loss, host_calls, scaffold
+ return loss, host_calls, scaffold, captured_eval_hooks.get()
def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- single_tpu_train_step, host_call, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
+ (single_tpu_train_step, host_call, captured_scaffold_fn,
+ captured_training_hooks) = (
+ model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
def multi_tpu_train_steps_on_single_shard():
return training_loop.repeat(
@@ -2747,15 +2784,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
- return loss, host_call, scaffold
+ return loss, host_call, scaffold, captured_training_hooks.get()
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
num_cores = ctx.num_cores
- single_tpu_predict_step, host_calls, captured_scaffold_fn = (
- model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn))
+ (single_tpu_predict_step, host_calls, captured_scaffold_fn,
+ captured_predict_hooks
+ ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
def multi_tpu_predict_steps_on_single_shard():
@@ -2775,7 +2813,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
- return dummy_predict_op, host_calls, scaffold
+ return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()
def _wrap_computation_in_while_loop(device, op_fn):
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index a9fd8f8e1a..9db9ccd01d 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -380,15 +380,12 @@ def _maybe_add_default_serving_output(export_outputs):
return export_outputs
-class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn',
- 'host_call'])):
+class _TPUEstimatorSpec(
+ collections.namedtuple('TPUEstimatorSpec', [
+ 'mode', 'predictions', 'loss', 'train_op', 'eval_metrics',
+ 'export_outputs', 'scaffold_fn', 'host_call', 'training_hooks',
+ 'evaluation_hooks', 'prediction_hooks'
+ ])):
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See
@@ -404,17 +401,24 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
eval_metrics=None,
export_outputs=None,
scaffold_fn=None,
- host_call=None):
+ host_call=None,
+ training_hooks=None,
+ evaluation_hooks=None,
+ prediction_hooks=None):
"""Creates a `_TPUEstimatorSpec` instance."""
- return super(_TPUEstimatorSpec, cls).__new__(cls,
- mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metrics=eval_metrics,
- export_outputs=export_outputs,
- scaffold_fn=scaffold_fn,
- host_call=host_call)
+ return super(_TPUEstimatorSpec, cls).__new__(
+ cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn,
+ host_call=host_call,
+ training_hooks=training_hooks,
+ evaluation_hooks=evaluation_hooks,
+ prediction_hooks=prediction_hooks)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
@@ -423,12 +427,16 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
else:
metric_fn, tensors = self.eval_metrics
eval_metric_ops = metric_fn(**tensors)
- return EstimatorSpec(mode=self.mode,
- predictions=self.predictions,
- loss=self.loss,
- train_op=self.train_op,
- eval_metric_ops=eval_metric_ops,
- export_outputs=self.export_outputs)
+ return EstimatorSpec(
+ mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs,
+ training_hooks=self.training_hooks,
+ evaluation_hooks=self.evaluation_hooks,
+ prediction_hooks=self.prediction_hooks)
def _check_is_tensor_or_operation(x, name):