diff options
author | Katherine Wu <kathywu@google.com> | 2018-08-31 12:07:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 12:12:17 -0700 |
commit | cda5ea80b86909fd20ff8a0f5ba914c5c03b876f (patch) | |
tree | 9bfaa9eea182a31df1a7210eeae7f066d491483d /tensorflow/python/estimator | |
parent | e894ca7c736c58a8e4c71f0c3f1b1f0c327fa924 (diff) |
Roll forward of commit 069f808e5c0462819bcd6c73c75491b00cdd42c2 (rolling back rollback cl/210656847).
Fixing reference to _get_export_outputs_for_spec in TFMA (This function was refactored out, so the string has been removed from the list of methods that are copied from core Estimator).
*** Original change description ***
Automated rollback of commit 069f808e5c0462819bcd6c73c75491b00cdd42c2
PiperOrigin-RevId: 211122893
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 47 | ||||
-rw-r--r-- | tensorflow/python/estimator/keras.py | 75 | ||||
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 43 |
3 files changed, 100 insertions, 65 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 44a60495d8..f4d4146e28 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -35,7 +35,6 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -958,7 +957,12 @@ class Estimator(object): mode=mode, config=self.config) - export_outputs = self._get_export_outputs_for_spec(estimator_spec) + export_outputs = model_fn_lib.export_outputs_for_mode( + mode=estimator_spec.mode, + serving_export_outputs=estimator_spec.export_outputs, + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + metrics=estimator_spec.eval_metric_ops) # Build the SignatureDefs from receivers and all outputs signature_def_map = export_helpers.build_all_signature_defs( @@ -1015,45 +1019,6 @@ class Estimator(object): else: builder.add_meta_graph(**meta_graph_kwargs) - def _get_export_outputs_for_spec(self, estimator_spec): - """Given an `EstimatorSpec`, determine what our export outputs should be. - - `EstimatorSpecs` contains `export_outputs` that are used for serving, but - for - training and eval graphs, we must wrap the tensors of interest in - appropriate `tf.estimator.export.ExportOutput` objects. - - Args: - estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported. - - Returns: - a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput` - object. - - Raises: - ValueError: if an appropriate `ExportOutput` cannot be found for the - passed `EstimatorSpec.mode` - """ - mode = estimator_spec.mode - if mode == model_fn_lib.ModeKeys.PREDICT: - outputs = estimator_spec.export_outputs - else: - if mode == model_fn_lib.ModeKeys.TRAIN: - output_class = export_output.TrainOutput - elif mode == model_fn_lib.ModeKeys.EVAL: - output_class = export_output.EvalOutput - else: - raise ValueError( - 'Export output type not found for mode: {}'.format(mode)) - - export_out = output_class( - loss=estimator_spec.loss, - predictions=estimator_spec.predictions, - metrics=estimator_spec.eval_metric_ops) - outputs = {mode: export_out} - - return outputs - def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 6361c6acc1..6b2765be82 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -182,10 +182,58 @@ def _clone_and_build_model(mode, K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) input_tensors, target_tensors = _convert_estimator_io_to_keras( keras_model, features, labels) - return models.clone_and_build_model( + + compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT) + + global_step = None + if compile_clone: + # Set iterations to the global step created by tf.train.create_global_step() + # which is automatically run in the estimator framework. + global_step = training_util.get_or_create_global_step() + K.track_variable(global_step) + + clone = models.clone_and_build_model( keras_model, input_tensors, target_tensors, custom_objects, - compile_clone=(mode != model_fn_lib.ModeKeys.PREDICT), - in_place_reset=(not keras_model._is_graph_network)) + compile_clone=compile_clone, + in_place_reset=(not keras_model._is_graph_network), + optimizer_iterations=global_step) + + return clone + + +def _convert_keras_metrics_to_estimator(model): + """Convert metrics from a Keras model to ops used by the Estimator framework. + + Args: + model: A `tf.keras.Model` object. + + Returns: + Dictionary mapping metric names to tuples of (value, update) ops. May return + `None` if the model does not contain any metrics. + """ + if not getattr(model, 'metrics', None): + return None + + # TODO(psv/fchollet): support stateful metrics + eval_metric_ops = {} + # When each metric maps to an output + if isinstance(model.metrics, dict): + for i, output_name in enumerate(model.metrics.keys()): + metric_name = model.metrics[output_name] + if callable(metric_name): + metric_name = metric_name.__name__ + # When some outputs use the same metric + if list(model.metrics.values()).count(metric_name) > 1: + metric_name += '_' + output_name + eval_metric_ops[metric_name] = metrics_module.mean( + model.metrics_tensors[i - len(model.metrics)]) + else: + for i, metric_name in enumerate(model.metrics): + if callable(metric_name): + metric_name = metric_name.__name__ + eval_metric_ops[metric_name] = metrics_module.mean( + model.metrics_tensors[i]) + return eval_metric_ops def _create_keras_model_fn(keras_model, custom_objects=None): @@ -237,26 +285,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): model._make_test_function() # pylint: disable=protected-access loss = model.total_loss - if model.metrics: - # TODO(psv/fchollet): support stateful metrics - eval_metric_ops = {} - # When each metric maps to an output - if isinstance(model.metrics, dict): - for i, output_name in enumerate(model.metrics.keys()): - metric_name = model.metrics[output_name] - if callable(metric_name): - metric_name = metric_name.__name__ - # When some outputs use the same metric - if list(model.metrics.values()).count(metric_name) > 1: - metric_name += '_' + output_name - eval_metric_ops[metric_name] = metrics_module.mean( - model.metrics_tensors[i - len(model.metrics)]) - else: - for i, metric_name in enumerate(model.metrics): - if callable(metric_name): - metric_name = metric_name.__name__ - eval_metric_ops[metric_name] = metrics_module.mean( - model.metrics_tensors[i]) + eval_metric_ops = _convert_keras_metrics_to_estimator(model) # Set train_op only during train. if mode is model_fn_lib.ModeKeys.TRAIN: diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index fd2787aeaf..439cc2e3a4 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -142,7 +142,7 @@ class EstimatorSpec( prediction. predictions: Predictions `Tensor` or dict of `Tensor`. loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`. - train_op: Op to run one training step. + train_op: Op for the training step. eval_metric_ops: Dict of metric results keyed by name. The values of the dict can be one of the following: (1) instance of `Metric` class. @@ -475,3 +475,44 @@ def _check_is_tensor(x, tensor_name): if not isinstance(x, ops.Tensor): raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x)) return x + + +def export_outputs_for_mode( + mode, serving_export_outputs=None, predictions=None, loss=None, + metrics=None): + """Util function for constructing a `ExportOutput` dict given a mode. + + The returned dict can be directly passed to `build_all_signature_defs` helper + function as the `export_outputs` argument, used for generating a SignatureDef + map. + + Args: + mode: A `ModeKeys` specifying the mode. + serving_export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict or None. + predictions: A dict of Tensors or single Tensor representing model + predictions. This argument is only used if serving_export_outputs is not + set. + loss: A dict of Tensors or single Tensor representing calculated loss. + metrics: A dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op + + Returns: + Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object + The key is the expected SignatureDef key for the mode. + + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the mode. + """ + # TODO(b/113185250): move all model export helper functions into an util file. + if mode == ModeKeys.PREDICT: + return _get_export_outputs(serving_export_outputs, predictions) + elif mode == ModeKeys.TRAIN: + return {mode: export_output_lib.TrainOutput( + loss=loss, predictions=predictions, metrics=metrics)} + elif mode == ModeKeys.EVAL: + return {mode: export_output_lib.EvalOutput( + loss=loss, predictions=predictions, metrics=metrics)} + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) |