diff options
author | Katherine Wu <kathywu@google.com> | 2018-08-28 19:00:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 19:04:24 -0700 |
commit | 069f808e5c0462819bcd6c73c75491b00cdd42c2 (patch) | |
tree | ce17bc2f4ec07ccd1baa790a335c192e0ab0fe2b /tensorflow/python/estimator/model_fn.py | |
parent | 4a83b950aa2b5be238bed118acb006a3fe1c806e (diff) |
Export Keras model to SavedModel.
PiperOrigin-RevId: 210648154
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 007970bef7..ea137a08cc 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -141,7 +141,7 @@ class EstimatorSpec( prediction. predictions: Predictions `Tensor` or dict of `Tensor`. loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`. - train_op: Op to run one training step. + train_op: Op for the training step. eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated @@ -449,3 +449,44 @@ def _check_is_tensor(x, tensor_name): if not isinstance(x, ops.Tensor): raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x)) return x + + +def export_outputs_for_mode( + mode, serving_export_outputs=None, predictions=None, loss=None, + metrics=None): + """Util function for constructing a `ExportOutput` dict given a mode. + + The returned dict can be directly passed to `build_all_signature_defs` helper + function as the `export_outputs` argument, used for generating a SignatureDef + map. + + Args: + mode: A `ModeKeys` specifying the mode. + serving_export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict or None. + predictions: A dict of Tensors or single Tensor representing model + predictions. This argument is only used if serving_export_outputs is not + set. + loss: A dict of Tensors or single Tensor representing calculated loss. + metrics: A dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op + + Returns: + Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object + The key is the expected SignatureDef key for the mode. + + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the mode. + """ + # TODO(b/113185250): move all model export helper functions into an util file. + if mode == ModeKeys.PREDICT: + return _get_export_outputs(serving_export_outputs, predictions) + elif mode == ModeKeys.TRAIN: + return {mode: export_output_lib.TrainOutput( + loss=loss, predictions=predictions, metrics=metrics)} + elif mode == ModeKeys.EVAL: + return {mode: export_output_lib.EvalOutput( + loss=loss, predictions=predictions, metrics=metrics)} + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) |