diff options
10 files changed, 481 insertions, 280 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index ef5a16d7dc..dc9aecff71 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -275,12 +275,12 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import BaseEstim from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn -from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey +from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index 7ca694f299..f8c0a6fe5d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import monitors as monitor_lib from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.contrib.losses.python.losses import loss_ops from tensorflow.python import summary @@ -236,7 +237,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params): activation_fn=activation_fn, variables_collections=[parent_scope], scope=scope) - if dropout is not None and mode == estimator.ModeKeys.TRAIN: + if dropout is not None and mode == model_fn.ModeKeys.TRAIN: net = layers.dropout( net, keep_prob=(1.0 - dropout)) @@ -257,7 +258,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params): if enable_centered_bias: logits = nn.bias_add(logits, _centered_bias(num_label_columns)) - if mode == estimator.ModeKeys.TRAIN: + if mode == model_fn.ModeKeys.TRAIN: labels = _reshape_labels(labels) weights = _get_weight_tensor(features, weight_column_name) training_loss = loss_fn(logits, labels, weights=weights) @@ -279,7 +280,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params): return None, loss, control_flow_ops.group(*train_ops) - elif mode == estimator.ModeKeys.EVAL: + elif mode == model_fn.ModeKeys.EVAL: predictions = _predictions(logits=logits, n_classes=n_classes) labels = _reshape_labels(labels) @@ -524,6 +525,12 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): return (pred[_PROBABILITIES] for pred in preds) return preds[_PROBABILITIES] + def _get_predict_ops(self, features): + """See `Estimator` class.""" + # This method exists to support some models that use the legacy interface. + # pylint: disable=protected-access + return self._estimator._get_predict_ops(features) + def get_variable_names(self): """Returns list of all variable names in this model. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 9bb2d353ec..0a9d0a3d1c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import composable_model from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.python.framework import ops @@ -228,11 +229,10 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): with ops.get_default_graph().colocate_with(global_step): return state_ops.assign_add(global_step, 1).op - model_fn_ops = self._head.head_ops(features, labels, - estimator.ModeKeys.TRAIN, - _make_training_op, - logits=logits) - return model_fn_ops.training_op, model_fn_ops.loss + return self._head.head_ops(features, labels, + model_fn.ModeKeys.TRAIN, + _make_training_op, + logits=logits) def _get_eval_ops(self, features, labels, metrics=None): """See base class.""" @@ -240,32 +240,30 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): features, labels = self._feature_engineering_fn(features, labels) logits = self._logits(features) - model_fn_ops = self._head.head_ops(features, labels, - estimator.ModeKeys.EVAL, None, - logits=logits) - all_metrics = model_fn_ops.default_metrics + eval_ops = self._head.head_ops(features, labels, + model_fn.ModeKeys.EVAL, None, + logits=logits) + custom_metrics = {} if metrics: for name, metric in six.iteritems(metrics): if not isinstance(name, tuple): # TODO(zakaria): remove once deprecation is finished (b/31229024) - all_metrics[(name, self._default_prediction_key)] = metric + custom_metrics[(name, self._default_prediction_key)] = metric else: - all_metrics[name] = metric + custom_metrics[name] = metric # TODO(zakaria): Remove this once we refactor this class to delegate # to estimator. - # pylint: disable=protected-access - result = estimator._make_metrics_ops(all_metrics, features, labels, - model_fn_ops.predictions) - return result + eval_ops.eval_metric_ops.update(estimator._make_metrics_ops( # pylint: disable=protected-access + custom_metrics, features, labels, eval_ops.predictions)) + return eval_ops def _get_predict_ops(self, features): """See base class.""" features = self._get_feature_dict(features) features, _ = self._feature_engineering_fn(features, None) logits = self._logits(features) - model_fn_ops = self._head.head_ops(features, None, estimator.ModeKeys.INFER, - None, logits=logits) - return model_fn_ops.predictions + return self._head.head_ops(features, None, model_fn.ModeKeys.INFER, + None, logits=logits) @deprecated( "2016-09-23", @@ -458,7 +456,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params): activation_fn=dnn_activation_fn, variables_collections=[dnn_parent_scope], scope=scope) - if dnn_dropout is not None and mode == estimator.ModeKeys.TRAIN: + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: net = layers.dropout( net, keep_prob=(1.0 - dnn_dropout)) @@ -785,9 +783,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): def _get_predict_ops(self, features): """See `Estimator` class.""" + # This method exists to support some models that use the legacy interface. # pylint: disable=protected-access - return self._estimator._get_predict_ops(features)[ - prediction_key.PredictionKey.PROBABILITIES] + return self._estimator._get_predict_ops(features) def get_variable_names(self): """Returns list of all variable names in this model. @@ -1041,8 +1039,36 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator): default_prediction_key=prediction_key.PredictionKey.SCORES, enable_centered_bias=enable_centered_bias) - def _get_predict_ops(self, features): - """See base class.""" - return super( - DNNLinearCombinedRegressor, - self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES] + @deprecated_arg_values( + estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS, + as_iterable=False) + def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True): + """Runs inference to determine the predicted class.""" + key = prediction_key.PredictionKey.SCORES + preds = super(DNNLinearCombinedRegressor, self).predict( + x=x, + input_fn=input_fn, + batch_size=batch_size, + outputs=[key], + as_iterable=as_iterable) + if as_iterable: + return _as_iterable(preds, output=key) + return preds[key] + + def export(self, + export_dir, + input_fn=None, + input_feature_key=None, + use_deprecated_input_fn=True, + signature_fn=None, + default_batch_size=None, + exports_to_keep=None): + return super(DNNLinearCombinedRegressor, self).export( + export_dir=export_dir, + input_fn=input_fn, + input_feature_key=input_feature_key, + use_deprecated_input_fn=use_deprecated_input_fn, + signature_fn=signature_fn, + prediction_key=prediction_key.PredictionKey.SCORES, + default_batch_size=default_batch_size, + exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 953e810bf6..f69ab30f7a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -572,10 +572,13 @@ def _get_dynamic_rnn_model_fn(cell, predict_probabilities) loss = _single_value_loss( rnn_activations, labels, sequence_length, target_column, features) + # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics. default_metrics = _get_default_metrics( problem_type, prediction_type, sequence_length) prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state - training_op = optimizers.optimize_loss( + eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access + default_metrics, features, labels, prediction_dict) + train_op = optimizers.optimize_loss( loss=loss, global_step=None, learning_rate=learning_rate, @@ -585,8 +588,8 @@ def _get_dynamic_rnn_model_fn(cell, return estimator.ModelFnOps(mode=mode, predictions=prediction_dict, loss=loss, - training_op=training_op, - default_metrics=default_metrics) + train_op=train_op, + eval_metric_ops=eval_metric_ops) return _dynamic_rnn_model_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 3daeb28253..b92ed3ccb0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import abc -import collections import copy import inspect import itertools @@ -37,7 +36,6 @@ from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework import deprecated_args -from tensorflow.contrib.framework import get_graph_from_inputs from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable from tensorflow.contrib.learn.python.learn import evaluable @@ -47,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import monitors as monitor_lib from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn from tensorflow.contrib.learn.python.learn.estimators import metric_key +from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import tensor_signature from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError @@ -56,8 +55,6 @@ from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_setter @@ -78,81 +75,13 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = ( ' est = Estimator(...) -> est = SKCompat(Estimator(...))') -class ModeKeys(object): - """Standard names for model modes. +# TODO(roumposg): Migrate external users to tf.learn.contrib.ModeKeys and delete +# this. +ModeKeys = model_fn_lib.ModeKeys # pylint: disable=invalid-name - The following standard keys are defined: - * `TRAIN`: training mode. - * `EVAL`: evaluation mode. - * `INFER`: inference mode. - """ - - TRAIN = 'train' - EVAL = 'eval' - INFER = 'infer' - - -class ModelFnOps( - collections.namedtuple('ModelFnOps', ['predictions', 'loss', 'training_op', - 'default_metrics', 'signature_fn'])): - - def __new__(cls, mode, predictions=None, loss=None, training_op=None, - default_metrics=None, signature_fn=None): - # Assert all ops are from the same graph. - get_graph_from_inputs((predictions, loss, training_op)) - - # Validate training_op. - if training_op is None: - if mode == ModeKeys.TRAIN: - raise ValueError('Missing training_op.') - elif not isinstance(training_op, ops.Operation): - # TODO(ptucker): Should this be allowed? Consider raising error. - training_op = ops.convert_to_tensor(training_op).op - - # Validate loss. - if loss is None: - if mode in (ModeKeys.TRAIN, ModeKeys.EVAL): - raise ValueError('Missing loss.') - else: - loss = ops.convert_to_tensor(loss) - loss_shape = loss.get_shape() - if loss_shape.num_elements() not in (None, 1): - raise ValueError('Loss must be scalar: %s.' % loss) - if not loss_shape.is_compatible_with(tensor_shape.scalar()): - loss = array_ops.reshape(loss, []) - - # Validate predictions. - if predictions is None: - if mode == ModeKeys.INFER or mode == ModeKeys.EVAL: - raise ValueError('Missing predictions.') - else: - if isinstance(predictions, dict): - predictions = { - k: contrib_framework.convert_to_tensor_or_sparse_tensor(v) - for k, v in six.iteritems(predictions) - } - else: - predictions = contrib_framework.convert_to_tensor_or_sparse_tensor( - predictions) - - # Validate default_metrics - if default_metrics is None: - default_metrics = {} - else: - if not isinstance(default_metrics, dict): - raise ValueError('default_metrics must be a dict.') - for k, v in default_metrics.items(): - if not isinstance(v, metric_spec.MetricSpec): - raise ValueError('Metric with key=%s is not MetricSpec' % k) - - # validate signature_fn - if signature_fn: - if not callable(signature_fn): - raise ValueError('signature_fn is not callable.') - - return super(ModelFnOps, cls).__new__(cls, predictions, loss, training_op, - default_metrics, signature_fn) +# TODO(roumposg): Migrate external users to model.ModelFnOps and delete this. +ModelFnOps = model_fn_lib.ModelFnOps # pylint: disable=invalid-name def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): @@ -645,7 +574,7 @@ class BaseEstimator( labels: `Tensor` or `dict` of `Tensor` objects. Returns: - Tuple of train `Operation` and loss `Tensor`. + A `ModelFnOps` object. """ pass @@ -657,7 +586,7 @@ class BaseEstimator( features: `Tensor` or `dict` of `Tensor` objects. Returns: - predictions: `Tensor` or `dict` of `Tensor` objects. + A `ModelFnOps` object. """ pass @@ -679,7 +608,7 @@ class BaseEstimator( `../metric_spec.py`. Returns: - metrics: `dict` of `Tensor` objects. + A `ModelFnOps` object. """ raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator') @@ -768,7 +697,23 @@ class BaseEstimator( global_step = contrib_framework.create_global_step(g) features, labels = input_fn() self._check_inputs(features, labels) - train_op, loss_op = self._get_train_ops(features, labels) + + # The default return type of _get_train_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_train_ops returns a + # (train_op, loss) tuple. The following else-statement code covers these + # cases, but will soon be deleted after the subclasses are updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + train_ops = self._get_train_ops(features, labels) + if isinstance(train_ops, ModelFnOps): # Default signature + train_op = train_ops.train_op + loss_op = train_ops.loss + else: # Legacy signature + if len(train_ops) != 2: + raise ValueError('Expected a tuple of train_op and loss, got {}'. + format(train_ops)) + train_op = train_ops[0] + loss_op = train_ops[1] hooks = monitor_lib.replace_monitors_with_hooks(monitors, self) @@ -845,7 +790,20 @@ class BaseEstimator( global_step = contrib_framework.create_global_step(g) features, labels = input_fn() self._check_inputs(features, labels) - eval_dict = self._get_eval_ops(features, labels, metrics) + + # The default return type of _get_eval_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_eval_ops returns an + # `eval_dict` dictionary of Tensors. The following else-statement code + # covers these cases, but will soon be deleted after the subclasses are + # updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + eval_ops = self._get_eval_ops(features, labels, metrics) + if isinstance(eval_ops, ModelFnOps): # Default signature + eval_dict = eval_ops.eval_metric_ops + else: # Legacy signature + eval_dict = eval_ops + update_op, eval_dict = self._extract_metric_update_ops(eval_dict) eval_results, current_global_step = graph_actions.evaluate( graph=g, @@ -878,7 +836,20 @@ class BaseEstimator( random_seed.set_random_seed(self._config.tf_random_seed) contrib_framework.create_global_step(g) features = self._get_features_from_input_fn(input_fn) - predictions = self._get_predict_ops(features) + + # The default return type of _get_predict_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_predict_ops returns a + # `predictions` Tensor or dict or Tensors. The following else-statement + # code covers these cases, but will soon be deleted after the subclasses + # are updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + infer_ops = self._get_predict_ops(features) + if isinstance(infer_ops, ModelFnOps): # Default signature + predictions = infer_ops.predictions + else: # Legacy signature + predictions = infer_ops + # If predictions is single output - wrap it into dict, and remember to # return not a dict. return_dict = isinstance(predictions, dict) @@ -965,11 +936,28 @@ class Estimator(BaseEstimator): config=None, params=None, feature_engineering_fn=None): - """Constructs an Estimator instance. + """Constructs an `Estimator` instance. Args: - model_fn: Model function, takes features and labels tensors or dicts of - tensors and returns tuple of: + model_fn: Model function. Follows the signature: + * Args: + * `features` are single `Tensor` or `dict` of `Tensor`s + (depending on data passed to `fit`), + * `labels` are `Tensor` or `dict` of `Tensor`s (for multi-head + models). If mode is `ModeKeys.INFER`, `labels=None` will be + passed. If the `model_fn`'s signature does not accept + `mode`, the `model_fn` must still be able to handle + `labels=None`. + * `mode` specifies if this training, evaluation or + prediction. See `ModeKeys`. + * `params` is a `dict` of hyperparameters. Will receive what + is passed to Estimator in `params` parameter. This allows + to configure Estimators from hyper parameter tuning. + + * Returns: + `ModelFnOps` + + Also supports a legacy signature which returns tuple of: * predictions: `Tensor`, `SparseTensor` or dictionary of same. Can also be any type that is convertible to a `Tensor` or @@ -977,27 +965,12 @@ class Estimator(BaseEstimator): * loss: Scalar loss `Tensor`. * train_op: Training update `Tensor` or `Operation`. - Supports next three signatures for the function: + Supports next three signatures for the function: * `(features, labels) -> (predictions, loss, train_op)` * `(features, labels, mode) -> (predictions, loss, train_op)` * `(features, labels, mode, params) -> (predictions, loss, train_op)` - Where - - * `features` are single `Tensor` or `dict` of `Tensor`s - (depending on data passed to `fit`), - * `labels` are `Tensor` or `dict` of `Tensor`s (for multi-head - models). If mode is `ModeKeys.INFER`, `labels=None` will be - passed. If the `model_fn`'s signature does not accept - `mode`, the `model_fn` must still be able to handle - `labels=None`. - * `mode` represents if this training, evaluation or - prediction. See `ModeKeys`. - * `params` is a `dict` of hyperparameters. Will receive what - is passed to Estimator in `params` parameter. This allows - to configure Estimators from hyper parameter tunning. - model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. @@ -1039,8 +1012,8 @@ class Estimator(BaseEstimator): mode: ModeKeys Returns: - A ModelFnOps object. If model_fn returns a tuple, wraps them up in a - ModelFnOps object. + A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a + `ModelFnOps` object. Raises: ValueError: if model_fn returns invalid objects. @@ -1067,7 +1040,7 @@ class Estimator(BaseEstimator): mode=mode, predictions=model_fn_results[0], loss=model_fn_results[1], - training_op=model_fn_results[2]) + train_op=model_fn_results[2]) def _get_train_ops(self, features, labels): """Method that builds model graph and returns trainer ops. @@ -1081,10 +1054,9 @@ class Estimator(BaseEstimator): labels: `Tensor` or `dict` of `Tensor` objects. Returns: - Tuple of train `Operation` and loss `Tensor`. + `ModelFnOps` object. """ - model_fn_ops = self._call_model_fn(features, labels, ModeKeys.TRAIN) - return model_fn_ops.training_op, model_fn_ops.loss + return self._call_model_fn(features, labels, ModeKeys.TRAIN) def _get_eval_ops(self, features, labels, metrics): """Method that builds model graph and returns evaluation ops. @@ -1106,24 +1078,22 @@ class Estimator(BaseEstimator): `../metric_spec.py`. Returns: - metrics: `dict` of `Tensor` objects. + `ModelFnOps` object. Raises: ValueError: if `metrics` don't match `labels`. """ model_fn_ops = self._call_model_fn(features, labels, ModeKeys.EVAL) - all_metrics = model_fn_ops.default_metrics # Custom metrics should overwrite defaults. if metrics: - all_metrics.update(metrics) + model_fn_ops.eval_metric_ops.update(_make_metrics_ops( + metrics, features, labels, model_fn_ops.predictions)) - result = _make_metrics_ops(all_metrics, features, labels, - model_fn_ops.predictions) - if metric_key.MetricKey.LOSS not in result: - result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean( - model_fn_ops.loss) - return result + if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: + model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( + metrics_lib.streaming_mean(model_fn_ops.loss)) + return model_fn_ops def _get_predict_ops(self, features): """Method that builds model graph and returns prediction ops. @@ -1136,12 +1106,11 @@ class Estimator(BaseEstimator): features: `Tensor` or `dict` of `Tensor` objects. Returns: - predictions: `Tensor` or `dict` of `Tensor` objects. + `ModelFnOps` object. """ labels = tensor_signature.create_placeholders_from_signatures( self._labels_info) - model_fn_ops = self._call_model_fn(features, labels, ModeKeys.INFER) - return model_fn_ops.predictions + return self._call_model_fn(features, labels, ModeKeys.INFER) # For time of deprecation x,y from Estimator allow direct access. diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index a96c64d33f..8d671e93ff 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -31,6 +31,7 @@ import tensorflow as tf from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import _sklearn from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import model_fn _BOSTON_INPUT_DIM = 13 @@ -99,6 +100,24 @@ def linear_model_fn(features, labels, mode): return prediction, loss, train_op +def linear_model_fn_with_model_fn_ops(features, labels, mode): + """Same as linear_model_fn, but returns `ModelFnOps`.""" + assert mode in ( + tf.contrib.learn.ModeKeys.TRAIN, + tf.contrib.learn.ModeKeys.EVAL, + tf.contrib.learn.ModeKeys.INFER) + prediction, loss = ( + tf.contrib.learn.models.linear_regression_zero_init(features, labels) + ) + train_op = tf.contrib.layers.optimize_loss( + loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad', + learning_rate=0.1) + return model_fn.ModelFnOps(mode=mode, + predictions=prediction, + loss=loss, + train_op=train_op) + + def logistic_model_no_mode_fn(features, labels): if isinstance(labels, dict): labels = labels['labels'] @@ -142,8 +161,9 @@ class EstimatorTest(tf.test.TestCase): def testInvalidModelFn_no_train_op(self): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument - tf.Variable(42.0, 'weight') - return None, None, None + w = tf.Variable(42.0, 'weight') + loss = 100.0 - w + return None, loss, None est = tf.contrib.learn.Estimator(model_fn=_invalid_model_fn) with self.assertRaisesRegexp(ValueError, 'Missing training_op'): est.fit(input_fn=boston_input_fn, steps=1) @@ -154,9 +174,10 @@ class EstimatorTest(tf.test.TestCase): w = tf.Variable(42.0, 'weight') loss = 100.0 - w train_op = w.assign_add(loss / 100.0) + predictions = loss if mode == tf.contrib.learn.ModeKeys.EVAL: loss = None - return None, loss, train_op + return predictions, loss, train_op est = tf.contrib.learn.Estimator(model_fn=_invalid_model_fn) est.fit(input_fn=boston_input_fn, steps=1) with self.assertRaisesRegexp(ValueError, 'Missing loss'): @@ -426,6 +447,17 @@ class EstimatorTest(tf.test.TestCase): output = list(est.predict(input_fn=input_fn)) self.assertEqual(len(output), boston.target.shape[0]) + def testWithModelFnOps(self): + """Test for model_fn that returns `ModelFnOps`.""" + est = tf.contrib.learn.Estimator(model_fn=linear_model_fn_with_model_fn_ops) + boston = tf.contrib.learn.datasets.load_boston() + est.fit(input_fn=boston_input_fn, steps=1) + input_fn = functools.partial(boston_input_fn, num_epochs=1) + scores = est.evaluate(input_fn=input_fn, steps=1) + self.assertIn('loss', scores.keys()) + output = list(est.predict(input_fn=input_fn)) + self.assertEqual(len(output), boston.target.shape[0]) + def testWrongInput(self): def other_input_fn(): return {'other': tf.constant([0, 0, 0])}, tf.constant([0, 0, 0]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index af566655ef..5de54de6ac 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -25,6 +25,7 @@ from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import metric_key +from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.session_bundle import exporter from tensorflow.python import summary @@ -200,6 +201,7 @@ class _Head(object): def logits_dimension(self): raise NotImplementedError("Calling an abstract method.") + @abc.abstractmethod def head_ops(self, features, labels, mode, train_op_fn, logits=None, logits_input=None): """Returns ops for a model_fn. @@ -214,71 +216,11 @@ class _Head(object): logits_input: tensor to build logits from. Returns: - `estimator.ModelFnOps` + `ModelFnOps`. Raises: ValueError: if mode is not recognized. """ - _check_logits_input_not_supported(logits, logits_input) - if mode == estimator.ModeKeys.TRAIN: - loss, additional_train_op = self._training_loss(features, labels, - logits, logits_input) - - train_op = train_op_fn(loss) - - if additional_train_op: - if train_op: - train_op = control_flow_ops.group(train_op, *additional_train_op) - else: - train_op = control_flow_ops.group(*additional_train_op) - - return estimator.ModelFnOps( - mode=estimator.ModeKeys.TRAIN, - loss=loss, - training_op=train_op, - default_metrics=self._default_metric(), - signature_fn=self._create_signature_fn()) - - if mode == estimator.ModeKeys.INFER: - return estimator.ModelFnOps( - mode=estimator.ModeKeys.INFER, - predictions=self._infer_op(logits, logits_input), - default_metrics=self._default_metric(), - signature_fn=self._create_signature_fn()) - - if mode == estimator.ModeKeys.EVAL: - predictions, loss = self._eval_op(features, labels, logits, logits_input) - return estimator.ModelFnOps( - mode=estimator.ModeKeys.EVAL, - predictions=predictions, - loss=loss, - default_metrics=self._default_metric(), - signature_fn=self._create_signature_fn()) - - raise ValueError("mode=%s unrecognized." % str(mode)) - - @abc.abstractmethod - def _training_loss(self, features, labels, logits=None, logits_input=None, - name="training_loss"): - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def _infer_op(self, logits=None, logits_input=None, name="infer_op"): - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def _eval_op(self, features, labels, logits=None, logits_input=None, - name="eval_op"): - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def _default_metric(self): - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def _create_signature_fn(self): - """Creates signature function for the Head. - """ raise NotImplementedError("Calling an abstract method.") @@ -319,8 +261,29 @@ class _RegressionHead(_Head): def logits_dimension(self): return self._logits_dimension - def _training_loss(self, features, labels, logits=None, - logits_input=None, name="training_loss"): + def head_ops(self, features, labels, mode, train_op_fn, logits=None, + logits_input=None): + """See `_Head`.""" + _check_mode_valid(mode) + _check_logits_input_not_supported(logits, logits_input) + predictions = self._predictions(logits) + loss = (None if labels is None + else self._training_loss(features, labels, logits)) + train_op = (None if labels is None or train_op_fn is None + else self._train_op(features, labels, train_op_fn, logits)) + eval_metric_ops = (None if labels is None + else self._eval_metric_ops(features, labels, logits)) + signature_fn = self._signature_fn() + + return model_fn.ModelFnOps( + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metric_ops=eval_metric_ops, + signature_fn=signature_fn) + + def _training_loss(self, features, labels, logits, name="training_loss"): """Returns training loss tensor for this head. Training loss is different from the loss reported on the tensorboard as we @@ -336,24 +299,17 @@ class _RegressionHead(_Head): labels: either a tensor for labels or in multihead case, a dict of string to labels tensor. logits: logits, a float tensor. - logits_input: Output of last hidden layer. name: Op name. Returns: - A tuple of training Loss and additional_train_op (possibly None) + A loss `Tensor`. """ labels = _check_labels(labels, self._label_name) - centered_bias_step = None if self._enable_centered_bias: logits = nn.bias_add(logits, _centered_bias( self.logits_dimension, self._centered_bias_weight_collection)) - centered_bias_step = [_centered_bias_step( - self.logits_dimension, - self._centered_bias_weight_collection, - labels, - self._train_loss_fn)] loss_unweighted = self._train_loss_fn(logits, labels) loss, weighted_average_loss = _loss( @@ -362,32 +318,54 @@ class _RegressionHead(_Head): name=name) summary.scalar( _head_prefixed(self._head_name, "loss"), weighted_average_loss) - return loss, centered_bias_step + return loss + + def _train_op(self, features, labels, train_op_fn, logits): + """Returns op for the training step.""" + loss = self._training_loss(features, labels, logits) + train_op = train_op_fn(loss) - def _eval_op(self, features, labels, logits=None, logits_input=None, - name="eval_op"): - labels = _check_labels(labels, self._label_name) if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( + centered_bias_step = [_centered_bias_step( self.logits_dimension, - self._centered_bias_weight_collection)) - loss_unweighted = self._eval_loss_fn(logits, labels) - loss, _ = _loss(loss_unweighted, - _weight_tensor(features, self._weight_column_name), - name=name) + self._centered_bias_weight_collection, + labels, + self._train_loss_fn)] + train_op = control_flow_ops.group(train_op, *centered_bias_step) + + return train_op + + def _eval_metric_ops(self, features, labels, logits): + """Returns a dict of metric ops keyed by name.""" + labels = _check_labels(labels, self._label_name) + predictions = self._predictions(logits) + return estimator._make_metrics_ops( # pylint: disable=protected-access + self._default_metrics(), features, labels, predictions) - predictions = self._logits_to_prediction(logits) + def _predictions(self, logits): + """Returns a dict of predictions. - return predictions, loss + Args: + logits: logits `Tensor` before applying possible centered bias. - def _infer_op(self, logits=None, logits_input=None): + Returns: + Dict of prediction `Tensor` keyed by `PredictionKey`. + """ if self._enable_centered_bias: logits = nn.bias_add(logits, _centered_bias( self.logits_dimension, self._centered_bias_weight_collection)) - return self._logits_to_prediction(logits) + return self._logits_to_predictions(logits) + + def _logits_to_predictions(self, logits): + """Returns a dict of predictions. + + Args: + logits: logits `Tensor` after applying possible centered bias. - def _logits_to_prediction(self, logits=None): + Returns: + Dict of prediction `Tensor` keyed by `PredictionKey`. + """ predictions = {} if self.logits_dimension == 1: predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze( @@ -396,8 +374,8 @@ class _RegressionHead(_Head): predictions[prediction_key.PredictionKey.SCORES] = logits return predictions - # pylint: disable=undefined-variable - def _create_signature_fn(self): + def _signature_fn(self): + """Returns the signature_fn to be used in exporting.""" def _regression_signature_fn(examples, unused_features, predictions): if isinstance(predictions, dict): score = predictions[prediction_key.PredictionKey.SCORES] @@ -410,7 +388,8 @@ class _RegressionHead(_Head): return default_signature, {} return _regression_signature_fn - def _default_metric(self): + def _default_metrics(self): + """Returns a dict of `MetricSpec` keyed by `MetricKey`.""" return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): _weighted_average_loss_metric_spec( self._eval_loss_fn, @@ -464,8 +443,29 @@ class _MultiClassHead(_Head): def logits_dimension(self): return self._logits_dimension - def _training_loss(self, features, labels, logits=None, - logits_input=None, name="training_loss"): + def head_ops(self, features, labels, mode, train_op_fn, logits=None, + logits_input=None): + """See `_Head`.""" + _check_mode_valid(mode) + _check_logits_input_not_supported(logits, logits_input) + predictions = self._predictions(logits) + loss = (None if labels is None + else self._training_loss(features, labels, logits)) + train_op = (None if labels is None or train_op_fn is None + else self._train_op(features, labels, train_op_fn, logits)) + eval_metric_ops = (None if labels is None + else self._eval_metric_ops(features, labels, logits)) + signature_fn = self._signature_fn() + + return model_fn.ModelFnOps( + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metric_ops=eval_metric_ops, + signature_fn=signature_fn) + + def _training_loss(self, features, labels, logits=None, name="training_loss"): """Returns training loss tensor for this head. Training loss is different from the loss reported on the tensorboard as we @@ -481,24 +481,17 @@ class _MultiClassHead(_Head): labels: either a tensor for labels or in multihead case, a dict of string to labels tensor. logits: logits, a float tensor. - logits_input: Output of last hidden layer. name: Op name. Returns: - A tuple of training Loss and additional_train_op (possibly None) + A loss `Tensor`. """ labels = _check_labels(labels, self._label_name) - centered_bias_step = None if self._enable_centered_bias: logits = nn.bias_add(logits, _centered_bias( self.logits_dimension, self._centered_bias_weight_collection)) - centered_bias_step = [_centered_bias_step( - self.logits_dimension, - self._centered_bias_weight_collection, - labels, - self._train_loss_fn)] loss_unweighted = self._train_loss_fn(logits, labels) loss, weighted_average_loss = _loss( @@ -507,33 +500,54 @@ class _MultiClassHead(_Head): name=name) summary.scalar( _head_prefixed(self._head_name, "loss"), weighted_average_loss) - return loss, centered_bias_step + return loss + + def _train_op(self, features, labels, train_op_fn, logits): + """Returns op for the training step.""" + loss = self._training_loss(features, labels, logits) + train_op = train_op_fn(loss) - def _eval_op(self, features, labels, logits=None, logits_input=None, - name="eval_op"): - labels = _check_labels(labels, self._label_name) if self._enable_centered_bias: - logits = nn.bias_add(logits, _centered_bias( + centered_bias_step = [_centered_bias_step( self.logits_dimension, - self._centered_bias_weight_collection)) - loss_unweighted = self._eval_loss_fn(logits, labels) - loss, _ = _loss(loss_unweighted, - _weight_tensor(features, self._weight_column_name), - name=name) + self._centered_bias_weight_collection, + labels, + self._train_loss_fn)] + train_op = control_flow_ops.group(train_op, *centered_bias_step) - predictions = self._logits_to_prediction(logits) + return train_op - return predictions, loss + def _eval_metric_ops(self, features, labels, logits): + """Returns a dict of metric ops keyed by name.""" + labels = _check_labels(labels, self._label_name) + predictions = self._predictions(logits) + return estimator._make_metrics_ops( # pylint: disable=protected-access + self._default_metrics(), features, labels, predictions) + + def _predictions(self, logits): + """Returns a dict of predictions. - def _infer_op(self, logits=None, logits_input=None): + Args: + logits: logits `Tensor` before applying possible centered bias. + + Returns: + Dict of prediction `Tensor` keyed by `PredictionKey`. + """ if self._enable_centered_bias: logits = nn.bias_add(logits, _centered_bias( self.logits_dimension, self._centered_bias_weight_collection)) - return self._logits_to_prediction(logits) + return self._logits_to_predictions(logits) + + def _logits_to_predictions(self, logits): + """Returns a dict of predictions. - def _logits_to_prediction(self, logits=None): - # pylint: disable=missing-docstring + Args: + logits: logits `Tensor` after applying possible centered bias. + + Returns: + Dict of prediction `Tensor` keyed by `PredictionKey`. + """ predictions = {prediction_key.PredictionKey.LOGITS: logits} if self.logits_dimension == 1: predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( @@ -546,8 +560,8 @@ class _MultiClassHead(_Head): return predictions - def _create_signature_fn(self): - """See superclass.""" + def _signature_fn(self): + """Returns the signature_fn to be used in exporting.""" def _classification_signature_fn(examples, unused_features, predictions): """Servo signature function.""" if isinstance(predictions, dict): @@ -565,7 +579,8 @@ class _MultiClassHead(_Head): return default_signature, {} return _classification_signature_fn - def _default_metric(self): + def _default_metrics(self): + """Returns a dict of `MetricSpec` objects keyed by name.""" metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): _weighted_average_loss_metric_spec( self._eval_loss_fn, @@ -646,7 +661,8 @@ class _BinarySvmHead(_MultiClassHead): head_name=head_name, thresholds=thresholds) - def _logits_to_prediction(self, logits=None): + def _logits_to_predictions(self, logits): + """See `_MultiClassHead`.""" predictions = {} predictions[prediction_key.PredictionKey.LOGITS] = logits logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) @@ -655,7 +671,8 @@ class _BinarySvmHead(_MultiClassHead): return predictions - def _default_metric(self): + def _default_metrics(self): + """See `_MultiClassHead`.""" metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): _weighted_average_loss_metric_spec( self._eval_loss_fn, @@ -689,7 +706,8 @@ class _MultiLabelHead(_MultiClassHead): head_name=head_name, thresholds=thresholds) - def _logits_to_prediction(self, logits=None): + def _logits_to_predictions(self, logits): + """See `_MultiClassHead`.""" predictions = {prediction_key.PredictionKey.LOGITS: logits} if self.logits_dimension == 1: predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( @@ -741,6 +759,14 @@ def _check_logits_input_not_supported(logits, logits_input): "must pass logits") +def _check_mode_valid(mode): + """Raises ValueError if the given mode is invalid.""" + if (mode != model_fn.ModeKeys.TRAIN and + mode != model_fn.ModeKeys.INFER and + mode != model_fn.ModeKeys.EVAL): + raise ValueError("mode=%s unrecognized." % str(mode)) + + def _centered_bias(logits_dimension, weight_collection): """Creates and returns centered bias.""" centered_bias = variables.Variable( diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 11996a7b4e..a1175c327d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -106,7 +106,7 @@ def _linear_model_fn(features, labels, mode, params): sparse and use the 'sum' combiner. Returns: - An `estimator.ModelFnOps` instance. + A `ModelFnOps` instance. Raises: ValueError: If mode is not any of the `ModeKeys`. @@ -179,7 +179,7 @@ def sdca_model_fn(features, labels, mode, params): model weights. Returns: - An `estimator.ModelFnOps` instance. + A `ModelFnOps` instance. Raises: ValueError: If `optimizer` is not an `SDCAOptimizer` instance. diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py new file mode 100644 index 0000000000..3f9351ce22 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -0,0 +1,125 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Classes and methods related to model_fn.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import six + +from tensorflow.contrib import framework as contrib_framework +from tensorflow.contrib.framework import get_graph_from_inputs + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops + + +class ModeKeys(object): + """Standard names for model modes. + + The following standard keys are defined: + + * `TRAIN`: training mode. + * `EVAL`: evaluation mode. + * `INFER`: inference mode. + """ + + TRAIN = 'train' + EVAL = 'eval' + INFER = 'infer' + + +# TODO(roumposg): Pass output_signature_fn instead of signature_fn. +class ModelFnOps(collections.namedtuple( + 'ModelFnOps', + ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])): + """Ops returned from a model_fn.""" + + def __new__(cls, mode, predictions=None, loss=None, train_op=None, + eval_metric_ops=None, signature_fn=None): + """Creates a validated `ModelFnOps` instance. + + Args: + mode: One of `ModeKeys`. Specifies if this training, evaluation or + prediction. + predictions: Predictions `Tensor` or dict of `Tensor`. + loss: Training loss `Tensor`. + train_op: Op for the training step. + eval_metric_ops: Dict of metric results keyed by name. The values of the + dict are the results of calling a metric function, such as `Tensor`. + signature_fn: The signature_fn used for exporting. + + Returns: + A validated `ModelFnOps` object. + + Raises: + ValueError: If validation fails. + """ + # Assert all ops are from the same graph. + get_graph_from_inputs((predictions, loss, train_op)) + + # Validate train_op. + if train_op is None: + if mode == ModeKeys.TRAIN: + raise ValueError('Missing training_op.') + elif not isinstance(train_op, ops.Operation): + # TODO(ptucker): Should this be allowed? Consider raising error. + train_op = ops.convert_to_tensor(train_op).op + + # Validate loss. + if loss is None: + if mode in (ModeKeys.TRAIN, ModeKeys.EVAL): + raise ValueError('Missing loss.') + else: + loss = ops.convert_to_tensor(loss) + loss_shape = loss.get_shape() + if loss_shape.num_elements() not in (None, 1): + raise ValueError('Loss must be scalar: %s.' % loss) + if not loss_shape.is_compatible_with(tensor_shape.scalar()): + loss = array_ops.reshape(loss, []) + + # Validate predictions. + if predictions is None: + if mode == ModeKeys.INFER or mode == ModeKeys.EVAL: + raise ValueError('Missing predictions.') + else: + if isinstance(predictions, dict): + predictions = { + k: contrib_framework.convert_to_tensor_or_sparse_tensor(v) + for k, v in six.iteritems(predictions) + } + else: + predictions = contrib_framework.convert_to_tensor_or_sparse_tensor( + predictions) + + # Validate eval_metric_ops + if eval_metric_ops is None: + eval_metric_ops = {} + else: + if not isinstance(eval_metric_ops, dict): + raise ValueError('eval_metric_ops must be a dict.') + + # validate signature_fn + if signature_fn: + if not callable(signature_fn): + raise ValueError('signature_fn is not callable.') + + return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op, + eval_metric_ops, signature_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 98bdced34f..fe2f8320f2 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.core.protobuf import saver_pb2 @@ -301,7 +302,19 @@ def _export_estimator(estimator, if (not features) and (examples is None): raise ValueError('Either features or examples must be defined.') - predictions = estimator._get_predict_ops(features) + # The default return type of _get_predict_ops is ModelFnOps. But there are + # some subclasses of tf.contrib.learn.Estimator which override this + # method and use the legacy signature, namely _get_predict_ops returns a + # `predictions` Tensor or dict or Tensors. The following else-statement + # code covers these cases, but will soon be deleted after the subclasses + # are updated. + # TODO(b/32664904): Update subclasses and delete the else-statement. + infer_ops = estimator._get_predict_ops(features) + if isinstance(infer_ops, model_fn.ModelFnOps): # Default signature + predictions = infer_ops.predictions + else: # Legacy signature + predictions = infer_ops + if prediction_key is not None: predictions = predictions[prediction_key] |