aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn.py
diff options
context:
space:
mode:
authorGravatar Karmel Allison <karmel@google.com>2018-05-04 16:01:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-05 08:30:01 -0700
commit008a3b69a601dc68fd940eb8a03b0c445714a339 (patch)
treedf7a92de37594adc3d8a3aef72baea1ea137fb1c /tensorflow/python/estimator/model_fn.py
parentab48fb528221152299fb08da8116d2eca54b8423 (diff)
Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode.
PiperOrigin-RevId: 195485922
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r--tensorflow/python/estimator/model_fn.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 8111ab564c..4ab2578769 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
@@ -53,6 +54,13 @@ class ModeKeys(object):
LOSS_METRIC_KEY = 'loss'
AVERAGE_LOSS_METRIC_KEY = 'average_loss'
+# Mapping of the modes to appropriate tag_constants that are used for saving.
+EXPORT_TAG_MAP = {
+ ModeKeys.PREDICT: [tag_constants.SERVING],
+ ModeKeys.TRAIN: [tag_constants.TRAINING],
+ ModeKeys.EVAL: [tag_constants.EVAL],
+}
+
@tf_export('estimator.EstimatorSpec')
class EstimatorSpec(