diff options
author | 2018-05-04 16:01:02 -0700 | |
---|---|---|
committer | 2018-05-05 08:30:01 -0700 | |
commit | 008a3b69a601dc68fd940eb8a03b0c445714a339 (patch) | |
tree | df7a92de37594adc3d8a3aef72baea1ea137fb1c /tensorflow/python/estimator/model_fn.py | |
parent | ab48fb528221152299fb08da8116d2eca54b8423 (diff) |
Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode.
PiperOrigin-RevId: 195485922
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 8111ab564c..4ab2578769 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest @@ -53,6 +54,13 @@ class ModeKeys(object): LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +# Mapping of the modes to appropriate tag_constants that are used for saving. +EXPORT_TAG_MAP = { + ModeKeys.PREDICT: [tag_constants.SERVING], + ModeKeys.TRAIN: [tag_constants.TRAINING], + ModeKeys.EVAL: [tag_constants.EVAL], +} + @tf_export('estimator.EstimatorSpec') class EstimatorSpec( |