aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/__init__.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn.py13
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py78
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py9
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py209
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py38
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py268
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/model_fn.py125
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py15
10 files changed, 481 insertions, 280 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index ef5a16d7dc..dc9aecff71 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -275,12 +275,12 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import BaseEstim
from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
-from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
+from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index 7ca694f299..f8c0a6fe5d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.python import summary
@@ -236,7 +237,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params):
activation_fn=activation_fn,
variables_collections=[parent_scope],
scope=scope)
- if dropout is not None and mode == estimator.ModeKeys.TRAIN:
+ if dropout is not None and mode == model_fn.ModeKeys.TRAIN:
net = layers.dropout(
net,
keep_prob=(1.0 - dropout))
@@ -257,7 +258,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params):
if enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(num_label_columns))
- if mode == estimator.ModeKeys.TRAIN:
+ if mode == model_fn.ModeKeys.TRAIN:
labels = _reshape_labels(labels)
weights = _get_weight_tensor(features, weight_column_name)
training_loss = loss_fn(logits, labels, weights=weights)
@@ -279,7 +280,7 @@ def _dnn_classifier_model_fn(features, labels, mode, params):
return None, loss, control_flow_ops.group(*train_ops)
- elif mode == estimator.ModeKeys.EVAL:
+ elif mode == model_fn.ModeKeys.EVAL:
predictions = _predictions(logits=logits, n_classes=n_classes)
labels = _reshape_labels(labels)
@@ -524,6 +525,12 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
return (pred[_PROBABILITIES] for pred in preds)
return preds[_PROBABILITIES]
+ def _get_predict_ops(self, features):
+ """See `Estimator` class."""
+ # This method exists to support some models that use the legacy interface.
+ # pylint: disable=protected-access
+ return self._estimator._get_predict_ops(features)
+
def get_variable_names(self):
"""Returns list of all variable names in this model.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 9bb2d353ec..0a9d0a3d1c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import composable_model
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import ops
@@ -228,11 +229,10 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
with ops.get_default_graph().colocate_with(global_step):
return state_ops.assign_add(global_step, 1).op
- model_fn_ops = self._head.head_ops(features, labels,
- estimator.ModeKeys.TRAIN,
- _make_training_op,
- logits=logits)
- return model_fn_ops.training_op, model_fn_ops.loss
+ return self._head.head_ops(features, labels,
+ model_fn.ModeKeys.TRAIN,
+ _make_training_op,
+ logits=logits)
def _get_eval_ops(self, features, labels, metrics=None):
"""See base class."""
@@ -240,32 +240,30 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
features, labels = self._feature_engineering_fn(features, labels)
logits = self._logits(features)
- model_fn_ops = self._head.head_ops(features, labels,
- estimator.ModeKeys.EVAL, None,
- logits=logits)
- all_metrics = model_fn_ops.default_metrics
+ eval_ops = self._head.head_ops(features, labels,
+ model_fn.ModeKeys.EVAL, None,
+ logits=logits)
+ custom_metrics = {}
if metrics:
for name, metric in six.iteritems(metrics):
if not isinstance(name, tuple):
# TODO(zakaria): remove once deprecation is finished (b/31229024)
- all_metrics[(name, self._default_prediction_key)] = metric
+ custom_metrics[(name, self._default_prediction_key)] = metric
else:
- all_metrics[name] = metric
+ custom_metrics[name] = metric
# TODO(zakaria): Remove this once we refactor this class to delegate
# to estimator.
- # pylint: disable=protected-access
- result = estimator._make_metrics_ops(all_metrics, features, labels,
- model_fn_ops.predictions)
- return result
+ eval_ops.eval_metric_ops.update(estimator._make_metrics_ops( # pylint: disable=protected-access
+ custom_metrics, features, labels, eval_ops.predictions))
+ return eval_ops
def _get_predict_ops(self, features):
"""See base class."""
features = self._get_feature_dict(features)
features, _ = self._feature_engineering_fn(features, None)
logits = self._logits(features)
- model_fn_ops = self._head.head_ops(features, None, estimator.ModeKeys.INFER,
- None, logits=logits)
- return model_fn_ops.predictions
+ return self._head.head_ops(features, None, model_fn.ModeKeys.INFER,
+ None, logits=logits)
@deprecated(
"2016-09-23",
@@ -458,7 +456,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params):
activation_fn=dnn_activation_fn,
variables_collections=[dnn_parent_scope],
scope=scope)
- if dnn_dropout is not None and mode == estimator.ModeKeys.TRAIN:
+ if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN:
net = layers.dropout(
net,
keep_prob=(1.0 - dnn_dropout))
@@ -785,9 +783,9 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
def _get_predict_ops(self, features):
"""See `Estimator` class."""
+ # This method exists to support some models that use the legacy interface.
# pylint: disable=protected-access
- return self._estimator._get_predict_ops(features)[
- prediction_key.PredictionKey.PROBABILITIES]
+ return self._estimator._get_predict_ops(features)
def get_variable_names(self):
"""Returns list of all variable names in this model.
@@ -1041,8 +1039,36 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
default_prediction_key=prediction_key.PredictionKey.SCORES,
enable_centered_bias=enable_centered_bias)
- def _get_predict_ops(self, features):
- """See base class."""
- return super(
- DNNLinearCombinedRegressor,
- self)._get_predict_ops(features)[prediction_key.PredictionKey.SCORES]
+ @deprecated_arg_values(
+ estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
+ as_iterable=False)
+ def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=True):
+ """Runs inference to determine the predicted class."""
+ key = prediction_key.PredictionKey.SCORES
+ preds = super(DNNLinearCombinedRegressor, self).predict(
+ x=x,
+ input_fn=input_fn,
+ batch_size=batch_size,
+ outputs=[key],
+ as_iterable=as_iterable)
+ if as_iterable:
+ return _as_iterable(preds, output=key)
+ return preds[key]
+
+ def export(self,
+ export_dir,
+ input_fn=None,
+ input_feature_key=None,
+ use_deprecated_input_fn=True,
+ signature_fn=None,
+ default_batch_size=None,
+ exports_to_keep=None):
+ return super(DNNLinearCombinedRegressor, self).export(
+ export_dir=export_dir,
+ input_fn=input_fn,
+ input_feature_key=input_feature_key,
+ use_deprecated_input_fn=use_deprecated_input_fn,
+ signature_fn=signature_fn,
+ prediction_key=prediction_key.PredictionKey.SCORES,
+ default_batch_size=default_batch_size,
+ exports_to_keep=exports_to_keep)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 953e810bf6..f69ab30f7a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -572,10 +572,13 @@ def _get_dynamic_rnn_model_fn(cell,
predict_probabilities)
loss = _single_value_loss(
rnn_activations, labels, sequence_length, target_column, features)
+ # TODO(roumposg): Return eval_metric_ops here, instead of default_metrics.
default_metrics = _get_default_metrics(
problem_type, prediction_type, sequence_length)
prediction_dict[RNNKeys.FINAL_STATE_KEY] = final_state
- training_op = optimizers.optimize_loss(
+ eval_metric_ops = estimator._make_metrics_ops( # pylint: disable=protected-access
+ default_metrics, features, labels, prediction_dict)
+ train_op = optimizers.optimize_loss(
loss=loss,
global_step=None,
learning_rate=learning_rate,
@@ -585,8 +588,8 @@ def _get_dynamic_rnn_model_fn(cell,
return estimator.ModelFnOps(mode=mode,
predictions=prediction_dict,
loss=loss,
- training_op=training_op,
- default_metrics=default_metrics)
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops)
return _dynamic_rnn_model_fn
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 3daeb28253..b92ed3ccb0 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import abc
-import collections
import copy
import inspect
import itertools
@@ -37,7 +36,6 @@ from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import deprecated_args
-from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.learn.python.learn import evaluable
@@ -47,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
from tensorflow.contrib.learn.python.learn.estimators import metric_key
+from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
@@ -56,8 +55,6 @@ from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
@@ -78,81 +75,13 @@ SCIKIT_DECOUPLE_INSTRUCTIONS = (
' est = Estimator(...) -> est = SKCompat(Estimator(...))')
-class ModeKeys(object):
- """Standard names for model modes.
+# TODO(roumposg): Migrate external users to tf.learn.contrib.ModeKeys and delete
+# this.
+ModeKeys = model_fn_lib.ModeKeys # pylint: disable=invalid-name
- The following standard keys are defined:
- * `TRAIN`: training mode.
- * `EVAL`: evaluation mode.
- * `INFER`: inference mode.
- """
-
- TRAIN = 'train'
- EVAL = 'eval'
- INFER = 'infer'
-
-
-class ModelFnOps(
- collections.namedtuple('ModelFnOps', ['predictions', 'loss', 'training_op',
- 'default_metrics', 'signature_fn'])):
-
- def __new__(cls, mode, predictions=None, loss=None, training_op=None,
- default_metrics=None, signature_fn=None):
- # Assert all ops are from the same graph.
- get_graph_from_inputs((predictions, loss, training_op))
-
- # Validate training_op.
- if training_op is None:
- if mode == ModeKeys.TRAIN:
- raise ValueError('Missing training_op.')
- elif not isinstance(training_op, ops.Operation):
- # TODO(ptucker): Should this be allowed? Consider raising error.
- training_op = ops.convert_to_tensor(training_op).op
-
- # Validate loss.
- if loss is None:
- if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
- raise ValueError('Missing loss.')
- else:
- loss = ops.convert_to_tensor(loss)
- loss_shape = loss.get_shape()
- if loss_shape.num_elements() not in (None, 1):
- raise ValueError('Loss must be scalar: %s.' % loss)
- if not loss_shape.is_compatible_with(tensor_shape.scalar()):
- loss = array_ops.reshape(loss, [])
-
- # Validate predictions.
- if predictions is None:
- if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
- raise ValueError('Missing predictions.')
- else:
- if isinstance(predictions, dict):
- predictions = {
- k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
- for k, v in six.iteritems(predictions)
- }
- else:
- predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
- predictions)
-
- # Validate default_metrics
- if default_metrics is None:
- default_metrics = {}
- else:
- if not isinstance(default_metrics, dict):
- raise ValueError('default_metrics must be a dict.')
- for k, v in default_metrics.items():
- if not isinstance(v, metric_spec.MetricSpec):
- raise ValueError('Metric with key=%s is not MetricSpec' % k)
-
- # validate signature_fn
- if signature_fn:
- if not callable(signature_fn):
- raise ValueError('signature_fn is not callable.')
-
- return super(ModelFnOps, cls).__new__(cls, predictions, loss, training_op,
- default_metrics, signature_fn)
+# TODO(roumposg): Migrate external users to model.ModelFnOps and delete this.
+ModelFnOps = model_fn_lib.ModelFnOps # pylint: disable=invalid-name
def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
@@ -645,7 +574,7 @@ class BaseEstimator(
labels: `Tensor` or `dict` of `Tensor` objects.
Returns:
- Tuple of train `Operation` and loss `Tensor`.
+ A `ModelFnOps` object.
"""
pass
@@ -657,7 +586,7 @@ class BaseEstimator(
features: `Tensor` or `dict` of `Tensor` objects.
Returns:
- predictions: `Tensor` or `dict` of `Tensor` objects.
+ A `ModelFnOps` object.
"""
pass
@@ -679,7 +608,7 @@ class BaseEstimator(
`../metric_spec.py`.
Returns:
- metrics: `dict` of `Tensor` objects.
+ A `ModelFnOps` object.
"""
raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator')
@@ -768,7 +697,23 @@ class BaseEstimator(
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
- train_op, loss_op = self._get_train_ops(features, labels)
+
+ # The default return type of _get_train_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_train_ops returns a
+ # (train_op, loss) tuple. The following else-statement code covers these
+ # cases, but will soon be deleted after the subclasses are updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ train_ops = self._get_train_ops(features, labels)
+ if isinstance(train_ops, ModelFnOps): # Default signature
+ train_op = train_ops.train_op
+ loss_op = train_ops.loss
+ else: # Legacy signature
+ if len(train_ops) != 2:
+ raise ValueError('Expected a tuple of train_op and loss, got {}'.
+ format(train_ops))
+ train_op = train_ops[0]
+ loss_op = train_ops[1]
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
@@ -845,7 +790,20 @@ class BaseEstimator(
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
- eval_dict = self._get_eval_ops(features, labels, metrics)
+
+ # The default return type of _get_eval_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_eval_ops returns an
+ # `eval_dict` dictionary of Tensors. The following else-statement code
+ # covers these cases, but will soon be deleted after the subclasses are
+ # updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ eval_ops = self._get_eval_ops(features, labels, metrics)
+ if isinstance(eval_ops, ModelFnOps): # Default signature
+ eval_dict = eval_ops.eval_metric_ops
+ else: # Legacy signature
+ eval_dict = eval_ops
+
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
eval_results, current_global_step = graph_actions.evaluate(
graph=g,
@@ -878,7 +836,20 @@ class BaseEstimator(
random_seed.set_random_seed(self._config.tf_random_seed)
contrib_framework.create_global_step(g)
features = self._get_features_from_input_fn(input_fn)
- predictions = self._get_predict_ops(features)
+
+ # The default return type of _get_predict_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_predict_ops returns a
+ # `predictions` Tensor or dict or Tensors. The following else-statement
+ # code covers these cases, but will soon be deleted after the subclasses
+ # are updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ infer_ops = self._get_predict_ops(features)
+ if isinstance(infer_ops, ModelFnOps): # Default signature
+ predictions = infer_ops.predictions
+ else: # Legacy signature
+ predictions = infer_ops
+
# If predictions is single output - wrap it into dict, and remember to
# return not a dict.
return_dict = isinstance(predictions, dict)
@@ -965,11 +936,28 @@ class Estimator(BaseEstimator):
config=None,
params=None,
feature_engineering_fn=None):
- """Constructs an Estimator instance.
+ """Constructs an `Estimator` instance.
Args:
- model_fn: Model function, takes features and labels tensors or dicts of
- tensors and returns tuple of:
+ model_fn: Model function. Follows the signature:
+ * Args:
+ * `features` are single `Tensor` or `dict` of `Tensor`s
+ (depending on data passed to `fit`),
+ * `labels` are `Tensor` or `dict` of `Tensor`s (for multi-head
+ models). If mode is `ModeKeys.INFER`, `labels=None` will be
+ passed. If the `model_fn`'s signature does not accept
+ `mode`, the `model_fn` must still be able to handle
+ `labels=None`.
+ * `mode` specifies if this training, evaluation or
+ prediction. See `ModeKeys`.
+ * `params` is a `dict` of hyperparameters. Will receive what
+ is passed to Estimator in `params` parameter. This allows
+ to configure Estimators from hyper parameter tuning.
+
+ * Returns:
+ `ModelFnOps`
+
+ Also supports a legacy signature which returns tuple of:
* predictions: `Tensor`, `SparseTensor` or dictionary of same.
Can also be any type that is convertible to a `Tensor` or
@@ -977,27 +965,12 @@ class Estimator(BaseEstimator):
* loss: Scalar loss `Tensor`.
* train_op: Training update `Tensor` or `Operation`.
- Supports next three signatures for the function:
+ Supports next three signatures for the function:
* `(features, labels) -> (predictions, loss, train_op)`
* `(features, labels, mode) -> (predictions, loss, train_op)`
* `(features, labels, mode, params) -> (predictions, loss, train_op)`
- Where
-
- * `features` are single `Tensor` or `dict` of `Tensor`s
- (depending on data passed to `fit`),
- * `labels` are `Tensor` or `dict` of `Tensor`s (for multi-head
- models). If mode is `ModeKeys.INFER`, `labels=None` will be
- passed. If the `model_fn`'s signature does not accept
- `mode`, the `model_fn` must still be able to handle
- `labels=None`.
- * `mode` represents if this training, evaluation or
- prediction. See `ModeKeys`.
- * `params` is a `dict` of hyperparameters. Will receive what
- is passed to Estimator in `params` parameter. This allows
- to configure Estimators from hyper parameter tunning.
-
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model.
@@ -1039,8 +1012,8 @@ class Estimator(BaseEstimator):
mode: ModeKeys
Returns:
- A ModelFnOps object. If model_fn returns a tuple, wraps them up in a
- ModelFnOps object.
+ A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a
+ `ModelFnOps` object.
Raises:
ValueError: if model_fn returns invalid objects.
@@ -1067,7 +1040,7 @@ class Estimator(BaseEstimator):
mode=mode,
predictions=model_fn_results[0],
loss=model_fn_results[1],
- training_op=model_fn_results[2])
+ train_op=model_fn_results[2])
def _get_train_ops(self, features, labels):
"""Method that builds model graph and returns trainer ops.
@@ -1081,10 +1054,9 @@ class Estimator(BaseEstimator):
labels: `Tensor` or `dict` of `Tensor` objects.
Returns:
- Tuple of train `Operation` and loss `Tensor`.
+ `ModelFnOps` object.
"""
- model_fn_ops = self._call_model_fn(features, labels, ModeKeys.TRAIN)
- return model_fn_ops.training_op, model_fn_ops.loss
+ return self._call_model_fn(features, labels, ModeKeys.TRAIN)
def _get_eval_ops(self, features, labels, metrics):
"""Method that builds model graph and returns evaluation ops.
@@ -1106,24 +1078,22 @@ class Estimator(BaseEstimator):
`../metric_spec.py`.
Returns:
- metrics: `dict` of `Tensor` objects.
+ `ModelFnOps` object.
Raises:
ValueError: if `metrics` don't match `labels`.
"""
model_fn_ops = self._call_model_fn(features, labels, ModeKeys.EVAL)
- all_metrics = model_fn_ops.default_metrics
# Custom metrics should overwrite defaults.
if metrics:
- all_metrics.update(metrics)
+ model_fn_ops.eval_metric_ops.update(_make_metrics_ops(
+ metrics, features, labels, model_fn_ops.predictions))
- result = _make_metrics_ops(all_metrics, features, labels,
- model_fn_ops.predictions)
- if metric_key.MetricKey.LOSS not in result:
- result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
- model_fn_ops.loss)
- return result
+ if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops:
+ model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = (
+ metrics_lib.streaming_mean(model_fn_ops.loss))
+ return model_fn_ops
def _get_predict_ops(self, features):
"""Method that builds model graph and returns prediction ops.
@@ -1136,12 +1106,11 @@ class Estimator(BaseEstimator):
features: `Tensor` or `dict` of `Tensor` objects.
Returns:
- predictions: `Tensor` or `dict` of `Tensor` objects.
+ `ModelFnOps` object.
"""
labels = tensor_signature.create_placeholders_from_signatures(
self._labels_info)
- model_fn_ops = self._call_model_fn(features, labels, ModeKeys.INFER)
- return model_fn_ops.predictions
+ return self._call_model_fn(features, labels, ModeKeys.INFER)
# For time of deprecation x,y from Estimator allow direct access.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index a96c64d33f..8d671e93ff 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -31,6 +31,7 @@ import tensorflow as tf
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
_BOSTON_INPUT_DIM = 13
@@ -99,6 +100,24 @@ def linear_model_fn(features, labels, mode):
return prediction, loss, train_op
+def linear_model_fn_with_model_fn_ops(features, labels, mode):
+ """Same as linear_model_fn, but returns `ModelFnOps`."""
+ assert mode in (
+ tf.contrib.learn.ModeKeys.TRAIN,
+ tf.contrib.learn.ModeKeys.EVAL,
+ tf.contrib.learn.ModeKeys.INFER)
+ prediction, loss = (
+ tf.contrib.learn.models.linear_regression_zero_init(features, labels)
+ )
+ train_op = tf.contrib.layers.optimize_loss(
+ loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
+ learning_rate=0.1)
+ return model_fn.ModelFnOps(mode=mode,
+ predictions=prediction,
+ loss=loss,
+ train_op=train_op)
+
+
def logistic_model_no_mode_fn(features, labels):
if isinstance(labels, dict):
labels = labels['labels']
@@ -142,8 +161,9 @@ class EstimatorTest(tf.test.TestCase):
def testInvalidModelFn_no_train_op(self):
def _invalid_model_fn(features, labels):
# pylint: disable=unused-argument
- tf.Variable(42.0, 'weight')
- return None, None, None
+ w = tf.Variable(42.0, 'weight')
+ loss = 100.0 - w
+ return None, loss, None
est = tf.contrib.learn.Estimator(model_fn=_invalid_model_fn)
with self.assertRaisesRegexp(ValueError, 'Missing training_op'):
est.fit(input_fn=boston_input_fn, steps=1)
@@ -154,9 +174,10 @@ class EstimatorTest(tf.test.TestCase):
w = tf.Variable(42.0, 'weight')
loss = 100.0 - w
train_op = w.assign_add(loss / 100.0)
+ predictions = loss
if mode == tf.contrib.learn.ModeKeys.EVAL:
loss = None
- return None, loss, train_op
+ return predictions, loss, train_op
est = tf.contrib.learn.Estimator(model_fn=_invalid_model_fn)
est.fit(input_fn=boston_input_fn, steps=1)
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
@@ -426,6 +447,17 @@ class EstimatorTest(tf.test.TestCase):
output = list(est.predict(input_fn=input_fn))
self.assertEqual(len(output), boston.target.shape[0])
+ def testWithModelFnOps(self):
+ """Test for model_fn that returns `ModelFnOps`."""
+ est = tf.contrib.learn.Estimator(model_fn=linear_model_fn_with_model_fn_ops)
+ boston = tf.contrib.learn.datasets.load_boston()
+ est.fit(input_fn=boston_input_fn, steps=1)
+ input_fn = functools.partial(boston_input_fn, num_epochs=1)
+ scores = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertIn('loss', scores.keys())
+ output = list(est.predict(input_fn=input_fn))
+ self.assertEqual(len(output), boston.target.shape[0])
+
def testWrongInput(self):
def other_input_fn():
return {'other': tf.constant([0, 0, 0])}, tf.constant([0, 0, 0])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index af566655ef..5de54de6ac 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -25,6 +25,7 @@ from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.session_bundle import exporter
from tensorflow.python import summary
@@ -200,6 +201,7 @@ class _Head(object):
def logits_dimension(self):
raise NotImplementedError("Calling an abstract method.")
+ @abc.abstractmethod
def head_ops(self, features, labels, mode, train_op_fn, logits=None,
logits_input=None):
"""Returns ops for a model_fn.
@@ -214,71 +216,11 @@ class _Head(object):
logits_input: tensor to build logits from.
Returns:
- `estimator.ModelFnOps`
+ `ModelFnOps`.
Raises:
ValueError: if mode is not recognized.
"""
- _check_logits_input_not_supported(logits, logits_input)
- if mode == estimator.ModeKeys.TRAIN:
- loss, additional_train_op = self._training_loss(features, labels,
- logits, logits_input)
-
- train_op = train_op_fn(loss)
-
- if additional_train_op:
- if train_op:
- train_op = control_flow_ops.group(train_op, *additional_train_op)
- else:
- train_op = control_flow_ops.group(*additional_train_op)
-
- return estimator.ModelFnOps(
- mode=estimator.ModeKeys.TRAIN,
- loss=loss,
- training_op=train_op,
- default_metrics=self._default_metric(),
- signature_fn=self._create_signature_fn())
-
- if mode == estimator.ModeKeys.INFER:
- return estimator.ModelFnOps(
- mode=estimator.ModeKeys.INFER,
- predictions=self._infer_op(logits, logits_input),
- default_metrics=self._default_metric(),
- signature_fn=self._create_signature_fn())
-
- if mode == estimator.ModeKeys.EVAL:
- predictions, loss = self._eval_op(features, labels, logits, logits_input)
- return estimator.ModelFnOps(
- mode=estimator.ModeKeys.EVAL,
- predictions=predictions,
- loss=loss,
- default_metrics=self._default_metric(),
- signature_fn=self._create_signature_fn())
-
- raise ValueError("mode=%s unrecognized." % str(mode))
-
- @abc.abstractmethod
- def _training_loss(self, features, labels, logits=None, logits_input=None,
- name="training_loss"):
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def _infer_op(self, logits=None, logits_input=None, name="infer_op"):
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def _eval_op(self, features, labels, logits=None, logits_input=None,
- name="eval_op"):
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def _default_metric(self):
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def _create_signature_fn(self):
- """Creates signature function for the Head.
- """
raise NotImplementedError("Calling an abstract method.")
@@ -319,8 +261,29 @@ class _RegressionHead(_Head):
def logits_dimension(self):
return self._logits_dimension
- def _training_loss(self, features, labels, logits=None,
- logits_input=None, name="training_loss"):
+ def head_ops(self, features, labels, mode, train_op_fn, logits=None,
+ logits_input=None):
+ """See `_Head`."""
+ _check_mode_valid(mode)
+ _check_logits_input_not_supported(logits, logits_input)
+ predictions = self._predictions(logits)
+ loss = (None if labels is None
+ else self._training_loss(features, labels, logits))
+ train_op = (None if labels is None or train_op_fn is None
+ else self._train_op(features, labels, train_op_fn, logits))
+ eval_metric_ops = (None if labels is None
+ else self._eval_metric_ops(features, labels, logits))
+ signature_fn = self._signature_fn()
+
+ return model_fn.ModelFnOps(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ signature_fn=signature_fn)
+
+ def _training_loss(self, features, labels, logits, name="training_loss"):
"""Returns training loss tensor for this head.
Training loss is different from the loss reported on the tensorboard as we
@@ -336,24 +299,17 @@ class _RegressionHead(_Head):
labels: either a tensor for labels or in multihead case, a dict of string
to labels tensor.
logits: logits, a float tensor.
- logits_input: Output of last hidden layer.
name: Op name.
Returns:
- A tuple of training Loss and additional_train_op (possibly None)
+ A loss `Tensor`.
"""
labels = _check_labels(labels, self._label_name)
- centered_bias_step = None
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
- centered_bias_step = [_centered_bias_step(
- self.logits_dimension,
- self._centered_bias_weight_collection,
- labels,
- self._train_loss_fn)]
loss_unweighted = self._train_loss_fn(logits, labels)
loss, weighted_average_loss = _loss(
@@ -362,32 +318,54 @@ class _RegressionHead(_Head):
name=name)
summary.scalar(
_head_prefixed(self._head_name, "loss"), weighted_average_loss)
- return loss, centered_bias_step
+ return loss
+
+ def _train_op(self, features, labels, train_op_fn, logits):
+ """Returns op for the training step."""
+ loss = self._training_loss(features, labels, logits)
+ train_op = train_op_fn(loss)
- def _eval_op(self, features, labels, logits=None, logits_input=None,
- name="eval_op"):
- labels = _check_labels(labels, self._label_name)
if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
+ centered_bias_step = [_centered_bias_step(
self.logits_dimension,
- self._centered_bias_weight_collection))
- loss_unweighted = self._eval_loss_fn(logits, labels)
- loss, _ = _loss(loss_unweighted,
- _weight_tensor(features, self._weight_column_name),
- name=name)
+ self._centered_bias_weight_collection,
+ labels,
+ self._train_loss_fn)]
+ train_op = control_flow_ops.group(train_op, *centered_bias_step)
+
+ return train_op
+
+ def _eval_metric_ops(self, features, labels, logits):
+ """Returns a dict of metric ops keyed by name."""
+ labels = _check_labels(labels, self._label_name)
+ predictions = self._predictions(logits)
+ return estimator._make_metrics_ops( # pylint: disable=protected-access
+ self._default_metrics(), features, labels, predictions)
- predictions = self._logits_to_prediction(logits)
+ def _predictions(self, logits):
+ """Returns a dict of predictions.
- return predictions, loss
+ Args:
+ logits: logits `Tensor` before applying possible centered bias.
- def _infer_op(self, logits=None, logits_input=None):
+ Returns:
+ Dict of prediction `Tensor` keyed by `PredictionKey`.
+ """
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
- return self._logits_to_prediction(logits)
+ return self._logits_to_predictions(logits)
+
+ def _logits_to_predictions(self, logits):
+ """Returns a dict of predictions.
+
+ Args:
+ logits: logits `Tensor` after applying possible centered bias.
- def _logits_to_prediction(self, logits=None):
+ Returns:
+ Dict of prediction `Tensor` keyed by `PredictionKey`.
+ """
predictions = {}
if self.logits_dimension == 1:
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
@@ -396,8 +374,8 @@ class _RegressionHead(_Head):
predictions[prediction_key.PredictionKey.SCORES] = logits
return predictions
- # pylint: disable=undefined-variable
- def _create_signature_fn(self):
+ def _signature_fn(self):
+ """Returns the signature_fn to be used in exporting."""
def _regression_signature_fn(examples, unused_features, predictions):
if isinstance(predictions, dict):
score = predictions[prediction_key.PredictionKey.SCORES]
@@ -410,7 +388,8 @@ class _RegressionHead(_Head):
return default_signature, {}
return _regression_signature_fn
- def _default_metric(self):
+ def _default_metrics(self):
+ """Returns a dict of `MetricSpec` keyed by `MetricKey`."""
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
@@ -464,8 +443,29 @@ class _MultiClassHead(_Head):
def logits_dimension(self):
return self._logits_dimension
- def _training_loss(self, features, labels, logits=None,
- logits_input=None, name="training_loss"):
+ def head_ops(self, features, labels, mode, train_op_fn, logits=None,
+ logits_input=None):
+ """See `_Head`."""
+ _check_mode_valid(mode)
+ _check_logits_input_not_supported(logits, logits_input)
+ predictions = self._predictions(logits)
+ loss = (None if labels is None
+ else self._training_loss(features, labels, logits))
+ train_op = (None if labels is None or train_op_fn is None
+ else self._train_op(features, labels, train_op_fn, logits))
+ eval_metric_ops = (None if labels is None
+ else self._eval_metric_ops(features, labels, logits))
+ signature_fn = self._signature_fn()
+
+ return model_fn.ModelFnOps(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ signature_fn=signature_fn)
+
+ def _training_loss(self, features, labels, logits=None, name="training_loss"):
"""Returns training loss tensor for this head.
Training loss is different from the loss reported on the tensorboard as we
@@ -481,24 +481,17 @@ class _MultiClassHead(_Head):
labels: either a tensor for labels or in multihead case, a dict of string
to labels tensor.
logits: logits, a float tensor.
- logits_input: Output of last hidden layer.
name: Op name.
Returns:
- A tuple of training Loss and additional_train_op (possibly None)
+ A loss `Tensor`.
"""
labels = _check_labels(labels, self._label_name)
- centered_bias_step = None
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
- centered_bias_step = [_centered_bias_step(
- self.logits_dimension,
- self._centered_bias_weight_collection,
- labels,
- self._train_loss_fn)]
loss_unweighted = self._train_loss_fn(logits, labels)
loss, weighted_average_loss = _loss(
@@ -507,33 +500,54 @@ class _MultiClassHead(_Head):
name=name)
summary.scalar(
_head_prefixed(self._head_name, "loss"), weighted_average_loss)
- return loss, centered_bias_step
+ return loss
+
+ def _train_op(self, features, labels, train_op_fn, logits):
+ """Returns op for the training step."""
+ loss = self._training_loss(features, labels, logits)
+ train_op = train_op_fn(loss)
- def _eval_op(self, features, labels, logits=None, logits_input=None,
- name="eval_op"):
- labels = _check_labels(labels, self._label_name)
if self._enable_centered_bias:
- logits = nn.bias_add(logits, _centered_bias(
+ centered_bias_step = [_centered_bias_step(
self.logits_dimension,
- self._centered_bias_weight_collection))
- loss_unweighted = self._eval_loss_fn(logits, labels)
- loss, _ = _loss(loss_unweighted,
- _weight_tensor(features, self._weight_column_name),
- name=name)
+ self._centered_bias_weight_collection,
+ labels,
+ self._train_loss_fn)]
+ train_op = control_flow_ops.group(train_op, *centered_bias_step)
- predictions = self._logits_to_prediction(logits)
+ return train_op
- return predictions, loss
+ def _eval_metric_ops(self, features, labels, logits):
+ """Returns a dict of metric ops keyed by name."""
+ labels = _check_labels(labels, self._label_name)
+ predictions = self._predictions(logits)
+ return estimator._make_metrics_ops( # pylint: disable=protected-access
+ self._default_metrics(), features, labels, predictions)
+
+ def _predictions(self, logits):
+ """Returns a dict of predictions.
- def _infer_op(self, logits=None, logits_input=None):
+ Args:
+ logits: logits `Tensor` before applying possible centered bias.
+
+ Returns:
+ Dict of prediction `Tensor` keyed by `PredictionKey`.
+ """
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
- return self._logits_to_prediction(logits)
+ return self._logits_to_predictions(logits)
+
+ def _logits_to_predictions(self, logits):
+ """Returns a dict of predictions.
- def _logits_to_prediction(self, logits=None):
- # pylint: disable=missing-docstring
+ Args:
+ logits: logits `Tensor` after applying possible centered bias.
+
+ Returns:
+ Dict of prediction `Tensor` keyed by `PredictionKey`.
+ """
predictions = {prediction_key.PredictionKey.LOGITS: logits}
if self.logits_dimension == 1:
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
@@ -546,8 +560,8 @@ class _MultiClassHead(_Head):
return predictions
- def _create_signature_fn(self):
- """See superclass."""
+ def _signature_fn(self):
+ """Returns the signature_fn to be used in exporting."""
def _classification_signature_fn(examples, unused_features, predictions):
"""Servo signature function."""
if isinstance(predictions, dict):
@@ -565,7 +579,8 @@ class _MultiClassHead(_Head):
return default_signature, {}
return _classification_signature_fn
- def _default_metric(self):
+ def _default_metrics(self):
+ """Returns a dict of `MetricSpec` objects keyed by name."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
@@ -646,7 +661,8 @@ class _BinarySvmHead(_MultiClassHead):
head_name=head_name,
thresholds=thresholds)
- def _logits_to_prediction(self, logits=None):
+ def _logits_to_predictions(self, logits):
+ """See `_MultiClassHead`."""
predictions = {}
predictions[prediction_key.PredictionKey.LOGITS] = logits
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
@@ -655,7 +671,8 @@ class _BinarySvmHead(_MultiClassHead):
return predictions
- def _default_metric(self):
+ def _default_metrics(self):
+ """See `_MultiClassHead`."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
@@ -689,7 +706,8 @@ class _MultiLabelHead(_MultiClassHead):
head_name=head_name,
thresholds=thresholds)
- def _logits_to_prediction(self, logits=None):
+ def _logits_to_predictions(self, logits):
+ """See `_MultiClassHead`."""
predictions = {prediction_key.PredictionKey.LOGITS: logits}
if self.logits_dimension == 1:
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
@@ -741,6 +759,14 @@ def _check_logits_input_not_supported(logits, logits_input):
"must pass logits")
+def _check_mode_valid(mode):
+ """Raises ValueError if the given mode is invalid."""
+ if (mode != model_fn.ModeKeys.TRAIN and
+ mode != model_fn.ModeKeys.INFER and
+ mode != model_fn.ModeKeys.EVAL):
+ raise ValueError("mode=%s unrecognized." % str(mode))
+
+
def _centered_bias(logits_dimension, weight_collection):
"""Creates and returns centered bias."""
centered_bias = variables.Variable(
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 11996a7b4e..a1175c327d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -106,7 +106,7 @@ def _linear_model_fn(features, labels, mode, params):
sparse and use the 'sum' combiner.
Returns:
- An `estimator.ModelFnOps` instance.
+ A `ModelFnOps` instance.
Raises:
ValueError: If mode is not any of the `ModeKeys`.
@@ -179,7 +179,7 @@ def sdca_model_fn(features, labels, mode, params):
model weights.
Returns:
- An `estimator.ModelFnOps` instance.
+ A `ModelFnOps` instance.
Raises:
ValueError: If `optimizer` is not an `SDCAOptimizer` instance.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
new file mode 100644
index 0000000000..3f9351ce22
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -0,0 +1,125 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Classes and methods related to model_fn."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import six
+
+from tensorflow.contrib import framework as contrib_framework
+from tensorflow.contrib.framework import get_graph_from_inputs
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+
+
+class ModeKeys(object):
+ """Standard names for model modes.
+
+ The following standard keys are defined:
+
+ * `TRAIN`: training mode.
+ * `EVAL`: evaluation mode.
+ * `INFER`: inference mode.
+ """
+
+ TRAIN = 'train'
+ EVAL = 'eval'
+ INFER = 'infer'
+
+
+# TODO(roumposg): Pass output_signature_fn instead of signature_fn.
+class ModelFnOps(collections.namedtuple(
+ 'ModelFnOps',
+ ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])):
+ """Ops returned from a model_fn."""
+
+ def __new__(cls, mode, predictions=None, loss=None, train_op=None,
+ eval_metric_ops=None, signature_fn=None):
+ """Creates a validated `ModelFnOps` instance.
+
+ Args:
+ mode: One of `ModeKeys`. Specifies if this training, evaluation or
+ prediction.
+ predictions: Predictions `Tensor` or dict of `Tensor`.
+ loss: Training loss `Tensor`.
+ train_op: Op for the training step.
+ eval_metric_ops: Dict of metric results keyed by name. The values of the
+ dict are the results of calling a metric function, such as `Tensor`.
+ signature_fn: The signature_fn used for exporting.
+
+ Returns:
+ A validated `ModelFnOps` object.
+
+ Raises:
+ ValueError: If validation fails.
+ """
+ # Assert all ops are from the same graph.
+ get_graph_from_inputs((predictions, loss, train_op))
+
+ # Validate train_op.
+ if train_op is None:
+ if mode == ModeKeys.TRAIN:
+ raise ValueError('Missing training_op.')
+ elif not isinstance(train_op, ops.Operation):
+ # TODO(ptucker): Should this be allowed? Consider raising error.
+ train_op = ops.convert_to_tensor(train_op).op
+
+ # Validate loss.
+ if loss is None:
+ if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
+ raise ValueError('Missing loss.')
+ else:
+ loss = ops.convert_to_tensor(loss)
+ loss_shape = loss.get_shape()
+ if loss_shape.num_elements() not in (None, 1):
+ raise ValueError('Loss must be scalar: %s.' % loss)
+ if not loss_shape.is_compatible_with(tensor_shape.scalar()):
+ loss = array_ops.reshape(loss, [])
+
+ # Validate predictions.
+ if predictions is None:
+ if mode == ModeKeys.INFER or mode == ModeKeys.EVAL:
+ raise ValueError('Missing predictions.')
+ else:
+ if isinstance(predictions, dict):
+ predictions = {
+ k: contrib_framework.convert_to_tensor_or_sparse_tensor(v)
+ for k, v in six.iteritems(predictions)
+ }
+ else:
+ predictions = contrib_framework.convert_to_tensor_or_sparse_tensor(
+ predictions)
+
+ # Validate eval_metric_ops
+ if eval_metric_ops is None:
+ eval_metric_ops = {}
+ else:
+ if not isinstance(eval_metric_ops, dict):
+ raise ValueError('eval_metric_ops must be a dict.')
+
+ # validate signature_fn
+ if signature_fn:
+ if not callable(signature_fn):
+ raise ValueError('signature_fn is not callable.')
+
+ return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
+ eval_metric_ops, signature_fn)
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 98bdced34f..fe2f8320f2 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.core.protobuf import saver_pb2
@@ -301,7 +302,19 @@ def _export_estimator(estimator,
if (not features) and (examples is None):
raise ValueError('Either features or examples must be defined.')
- predictions = estimator._get_predict_ops(features)
+ # The default return type of _get_predict_ops is ModelFnOps. But there are
+ # some subclasses of tf.contrib.learn.Estimator which override this
+ # method and use the legacy signature, namely _get_predict_ops returns a
+ # `predictions` Tensor or dict or Tensors. The following else-statement
+ # code covers these cases, but will soon be deleted after the subclasses
+ # are updated.
+ # TODO(b/32664904): Update subclasses and delete the else-statement.
+ infer_ops = estimator._get_predict_ops(features)
+ if isinstance(infer_ops, model_fn.ModelFnOps): # Default signature
+ predictions = infer_ops.predictions
+ else: # Legacy signature
+ predictions = infer_ops
+
if prediction_key is not None:
predictions = predictions[prediction_key]