aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--tensorflow/python/estimator/estimator.py47
1 files changed, 6 insertions, 41 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 44a60495d8..f4d4146e28 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -35,7 +35,6 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
-from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -958,7 +957,12 @@ class Estimator(object):
mode=mode,
config=self.config)
- export_outputs = self._get_export_outputs_for_spec(estimator_spec)
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode=estimator_spec.mode,
+ serving_export_outputs=estimator_spec.export_outputs,
+ predictions=estimator_spec.predictions,
+ loss=estimator_spec.loss,
+ metrics=estimator_spec.eval_metric_ops)
# Build the SignatureDefs from receivers and all outputs
signature_def_map = export_helpers.build_all_signature_defs(
@@ -1015,45 +1019,6 @@ class Estimator(object):
else:
builder.add_meta_graph(**meta_graph_kwargs)
- def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an `EstimatorSpec`, determine what our export outputs should be.
-
- `EstimatorSpecs` contains `export_outputs` that are used for serving, but
- for
- training and eval graphs, we must wrap the tensors of interest in
- appropriate `tf.estimator.export.ExportOutput` objects.
-
- Args:
- estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
-
- Returns:
- a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
- object.
-
- Raises:
- ValueError: if an appropriate `ExportOutput` cannot be found for the
- passed `EstimatorSpec.mode`
- """
- mode = estimator_spec.mode
- if mode == model_fn_lib.ModeKeys.PREDICT:
- outputs = estimator_spec.export_outputs
- else:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- output_class = export_output.TrainOutput
- elif mode == model_fn_lib.ModeKeys.EVAL:
- output_class = export_output.EvalOutput
- else:
- raise ValueError(
- 'Export output type not found for mode: {}'.format(mode))
-
- export_out = output_class(
- loss=estimator_spec.loss,
- predictions=estimator_spec.predictions,
- metrics=estimator_spec.eval_metric_ops)
- outputs = {mode: export_out}
-
- return outputs
-
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)