aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn.py
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-08-28 19:00:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 19:04:24 -0700
commit069f808e5c0462819bcd6c73c75491b00cdd42c2 (patch)
treece17bc2f4ec07ccd1baa790a335c192e0ab0fe2b /tensorflow/python/estimator/model_fn.py
parent4a83b950aa2b5be238bed118acb006a3fe1c806e (diff)
Export Keras model to SavedModel.
PiperOrigin-RevId: 210648154
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r--tensorflow/python/estimator/model_fn.py43
1 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 007970bef7..ea137a08cc 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -141,7 +141,7 @@ class EstimatorSpec(
prediction.
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
- train_op: Op to run one training step.
+ train_op: Op for the training step.
eval_metric_ops: Dict of metric results keyed by name. The values of the
dict are the results of calling a metric function, namely a
`(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
@@ -449,3 +449,44 @@ def _check_is_tensor(x, tensor_name):
if not isinstance(x, ops.Tensor):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
+
+
+def export_outputs_for_mode(
+ mode, serving_export_outputs=None, predictions=None, loss=None,
+ metrics=None):
+ """Util function for constructing a `ExportOutput` dict given a mode.
+
+ The returned dict can be directly passed to `build_all_signature_defs` helper
+ function as the `export_outputs` argument, used for generating a SignatureDef
+ map.
+
+ Args:
+ mode: A `ModeKeys` specifying the mode.
+ serving_export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict or None.
+ predictions: A dict of Tensors or single Tensor representing model
+ predictions. This argument is only used if serving_export_outputs is not
+ set.
+ loss: A dict of Tensors or single Tensor representing calculated loss.
+ metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op
+
+ Returns:
+ Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object
+ The key is the expected SignatureDef key for the mode.
+
+ Raises:
+ ValueError: if an appropriate ExportOutput cannot be found for the mode.
+ """
+ # TODO(b/113185250): move all model export helper functions into an util file.
+ if mode == ModeKeys.PREDICT:
+ return _get_export_outputs(serving_export_outputs, predictions)
+ elif mode == ModeKeys.TRAIN:
+ return {mode: export_output_lib.TrainOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ elif mode == ModeKeys.EVAL:
+ return {mode: export_output_lib.EvalOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ else:
+ raise ValueError(
+ 'Export output type not found for mode: {}'.format(mode))