diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 47 |
1 files changed, 6 insertions, 41 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 44a60495d8..f4d4146e28 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -35,7 +35,6 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -958,7 +957,12 @@ class Estimator(object): mode=mode, config=self.config) - export_outputs = self._get_export_outputs_for_spec(estimator_spec) + export_outputs = model_fn_lib.export_outputs_for_mode( + mode=estimator_spec.mode, + serving_export_outputs=estimator_spec.export_outputs, + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + metrics=estimator_spec.eval_metric_ops) # Build the SignatureDefs from receivers and all outputs signature_def_map = export_helpers.build_all_signature_defs( @@ -1015,45 +1019,6 @@ class Estimator(object): else: builder.add_meta_graph(**meta_graph_kwargs) - def _get_export_outputs_for_spec(self, estimator_spec): - """Given an `EstimatorSpec`, determine what our export outputs should be. - - `EstimatorSpecs` contains `export_outputs` that are used for serving, but - for - training and eval graphs, we must wrap the tensors of interest in - appropriate `tf.estimator.export.ExportOutput` objects. - - Args: - estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported. - - Returns: - a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput` - object. - - Raises: - ValueError: if an appropriate `ExportOutput` cannot be found for the - passed `EstimatorSpec.mode` - """ - mode = estimator_spec.mode - if mode == model_fn_lib.ModeKeys.PREDICT: - outputs = estimator_spec.export_outputs - else: - if mode == model_fn_lib.ModeKeys.TRAIN: - output_class = export_output.TrainOutput - elif mode == model_fn_lib.ModeKeys.EVAL: - output_class = export_output.EvalOutput - else: - raise ValueError( - 'Export output type not found for mode: {}'.format(mode)) - - export_out = output_class( - loss=estimator_spec.loss, - predictions=estimator_spec.predictions, - metrics=estimator_spec.eval_metric_ops) - outputs = {mode: export_out} - - return outputs - def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) |