aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-08-31 12:07:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 12:12:17 -0700
commitcda5ea80b86909fd20ff8a0f5ba914c5c03b876f (patch)
tree9bfaa9eea182a31df1a7210eeae7f066d491483d /tensorflow/python/estimator
parente894ca7c736c58a8e4c71f0c3f1b1f0c327fa924 (diff)
Roll forward of commit 069f808e5c0462819bcd6c73c75491b00cdd42c2 (rolling back rollback cl/210656847).
Fixing reference to _get_export_outputs_for_spec in TFMA (This function was refactored out, so the string has been removed from the list of methods that are copied from core Estimator). *** Original change description *** Automated rollback of commit 069f808e5c0462819bcd6c73c75491b00cdd42c2 PiperOrigin-RevId: 211122893
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/estimator.py47
-rw-r--r--tensorflow/python/estimator/keras.py75
-rw-r--r--tensorflow/python/estimator/model_fn.py43
3 files changed, 100 insertions, 65 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 44a60495d8..f4d4146e28 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -35,7 +35,6 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
-from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -958,7 +957,12 @@ class Estimator(object):
mode=mode,
config=self.config)
- export_outputs = self._get_export_outputs_for_spec(estimator_spec)
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode=estimator_spec.mode,
+ serving_export_outputs=estimator_spec.export_outputs,
+ predictions=estimator_spec.predictions,
+ loss=estimator_spec.loss,
+ metrics=estimator_spec.eval_metric_ops)
# Build the SignatureDefs from receivers and all outputs
signature_def_map = export_helpers.build_all_signature_defs(
@@ -1015,45 +1019,6 @@ class Estimator(object):
else:
builder.add_meta_graph(**meta_graph_kwargs)
- def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an `EstimatorSpec`, determine what our export outputs should be.
-
- `EstimatorSpecs` contains `export_outputs` that are used for serving, but
- for
- training and eval graphs, we must wrap the tensors of interest in
- appropriate `tf.estimator.export.ExportOutput` objects.
-
- Args:
- estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
-
- Returns:
- a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
- object.
-
- Raises:
- ValueError: if an appropriate `ExportOutput` cannot be found for the
- passed `EstimatorSpec.mode`
- """
- mode = estimator_spec.mode
- if mode == model_fn_lib.ModeKeys.PREDICT:
- outputs = estimator_spec.export_outputs
- else:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- output_class = export_output.TrainOutput
- elif mode == model_fn_lib.ModeKeys.EVAL:
- output_class = export_output.EvalOutput
- else:
- raise ValueError(
- 'Export output type not found for mode: {}'.format(mode))
-
- export_out = output_class(
- loss=estimator_spec.loss,
- predictions=estimator_spec.predictions,
- metrics=estimator_spec.eval_metric_ops)
- outputs = {mode: export_out}
-
- return outputs
-
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6361c6acc1..6b2765be82 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -182,10 +182,58 @@ def _clone_and_build_model(mode,
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
input_tensors, target_tensors = _convert_estimator_io_to_keras(
keras_model, features, labels)
- return models.clone_and_build_model(
+
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
+
+ global_step = None
+ if compile_clone:
+ # Set iterations to the global step created by tf.train.create_global_step()
+ # which is automatically run in the estimator framework.
+ global_step = training_util.get_or_create_global_step()
+ K.track_variable(global_step)
+
+ clone = models.clone_and_build_model(
keras_model, input_tensors, target_tensors, custom_objects,
- compile_clone=(mode != model_fn_lib.ModeKeys.PREDICT),
- in_place_reset=(not keras_model._is_graph_network))
+ compile_clone=compile_clone,
+ in_place_reset=(not keras_model._is_graph_network),
+ optimizer_iterations=global_step)
+
+ return clone
+
+
+def _convert_keras_metrics_to_estimator(model):
+ """Convert metrics from a Keras model to ops used by the Estimator framework.
+
+ Args:
+ model: A `tf.keras.Model` object.
+
+ Returns:
+ Dictionary mapping metric names to tuples of (value, update) ops. May return
+ `None` if the model does not contain any metrics.
+ """
+ if not getattr(model, 'metrics', None):
+ return None
+
+ # TODO(psv/fchollet): support stateful metrics
+ eval_metric_ops = {}
+ # When each metric maps to an output
+ if isinstance(model.metrics, dict):
+ for i, output_name in enumerate(model.metrics.keys()):
+ metric_name = model.metrics[output_name]
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ # When some outputs use the same metric
+ if list(model.metrics.values()).count(metric_name) > 1:
+ metric_name += '_' + output_name
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
+ else:
+ for i, metric_name in enumerate(model.metrics):
+ if callable(metric_name):
+ metric_name = metric_name.__name__
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
+ return eval_metric_ops
def _create_keras_model_fn(keras_model, custom_objects=None):
@@ -237,26 +285,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
model._make_test_function() # pylint: disable=protected-access
loss = model.total_loss
- if model.metrics:
- # TODO(psv/fchollet): support stateful metrics
- eval_metric_ops = {}
- # When each metric maps to an output
- if isinstance(model.metrics, dict):
- for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
- # When some outputs use the same metric
- if list(model.metrics.values()).count(metric_name) > 1:
- metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
- else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ eval_metric_ops = _convert_keras_metrics_to_estimator(model)
# Set train_op only during train.
if mode is model_fn_lib.ModeKeys.TRAIN:
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index fd2787aeaf..439cc2e3a4 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -142,7 +142,7 @@ class EstimatorSpec(
prediction.
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
- train_op: Op to run one training step.
+ train_op: Op for the training step.
eval_metric_ops: Dict of metric results keyed by name.
The values of the dict can be one of the following:
(1) instance of `Metric` class.
@@ -475,3 +475,44 @@ def _check_is_tensor(x, tensor_name):
if not isinstance(x, ops.Tensor):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
+
+
+def export_outputs_for_mode(
+ mode, serving_export_outputs=None, predictions=None, loss=None,
+ metrics=None):
+ """Util function for constructing a `ExportOutput` dict given a mode.
+
+ The returned dict can be directly passed to `build_all_signature_defs` helper
+ function as the `export_outputs` argument, used for generating a SignatureDef
+ map.
+
+ Args:
+ mode: A `ModeKeys` specifying the mode.
+ serving_export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict or None.
+ predictions: A dict of Tensors or single Tensor representing model
+ predictions. This argument is only used if serving_export_outputs is not
+ set.
+ loss: A dict of Tensors or single Tensor representing calculated loss.
+ metrics: A dict of (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op
+
+ Returns:
+ Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object
+ The key is the expected SignatureDef key for the mode.
+
+ Raises:
+ ValueError: if an appropriate ExportOutput cannot be found for the mode.
+ """
+ # TODO(b/113185250): move all model export helper functions into an util file.
+ if mode == ModeKeys.PREDICT:
+ return _get_export_outputs(serving_export_outputs, predictions)
+ elif mode == ModeKeys.TRAIN:
+ return {mode: export_output_lib.TrainOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ elif mode == ModeKeys.EVAL:
+ return {mode: export_output_lib.EvalOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
+ else:
+ raise ValueError(
+ 'Export output type not found for mode: {}'.format(mode))