aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/head.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py143
1 files changed, 38 insertions, 105 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index a9311a20f1..e344ee3c3e 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
@@ -49,20 +48,7 @@ def multi_class_head(n_classes,
Uses `sparse_softmax_cross_entropy` loss.
- The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
- In many applications, the shape is `[batch_size, n_classes]`.
-
- `labels` must be a dense `Tensor` with shape matching `logits`, namely
- `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
- `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
- `labels` must be an integer `Tensor` with values specifying the class index.
-
- If `weight_column` is specified, weights must be of shape
- `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
-
- The loss is the weighted sum over the input dimensions. Namely, if the input
- labels have shape `[batch_size, 1]`, the loss is the weighted sum over
- `batch_size`.
+ This head expects to be fed integer labels specifying the class index.
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
@@ -71,11 +57,11 @@ def multi_class_head(n_classes,
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
- label_vocabulary: A list or tuple of strings representing possible label
- values. If it is not given, that means labels are already encoded as an
- integer within [0, n_classes). If given, labels must be of string type and
- have any value in `label_vocabulary`. Note that errors will be raised if
- `label_vocabulary` is not provided but labels are strings.
+ label_vocabulary: A list of strings represents possible label values. If it
+ is not given, that means labels are already encoded as integer within
+ [0, n_classes). If given, labels must be string type and have any value in
+ `label_vocabulary`. Also there will be errors if vocabulary is not
+ provided and labels are string.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -98,20 +84,7 @@ def binary_classification_head(
This head uses `sigmoid_cross_entropy_with_logits` loss.
- The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
- In many applications, the shape is `[batch_size, 1]`.
-
- `labels` must be a dense `Tensor` with shape matching `logits`, namely
- `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
- `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
- `labels` must be float `Tensor` with values in the interval `[0, 1]`.
-
- If `weight_column` is specified, weights must be of shape
- `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
-
- The loss is the weighted sum over the input dimensions. Namely, if the input
- labels have shape `[batch_size, 1]`, the loss is the weighted sum over
- `batch_size`.
+ This head expects to be fed float labels of shape `(batch_size, 1)`.
Args:
weight_column: A string or a `_NumericColumn` created by
@@ -123,11 +96,11 @@ def binary_classification_head(
generated for each threshold value. This threshold is applied to the
logistic values to determine the binary classification (i.e., above the
threshold is `true`, below is `false`.
- label_vocabulary: A list or tuple of strings representing possible label
- values. If it is not given, labels must be float with values within
- [0, 1]. If given, labels must be string type and have any value in
- `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
- is not provided but labels are strings.
+ label_vocabulary: A list of strings represents possible label values. If it
+ is not given, that means labels are already encoded within [0, 1]. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. Also there will be errors if vocabulary is not
+ provided and labels are string.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -147,22 +120,9 @@ def binary_classification_head(
def regression_head(weight_column=None,
label_dimension=1,
name=None):
- """Creates a `_Head` for regression using the `mean_squared_error` loss.
-
- The loss is the weighted sum over all input dimensions. Namely, if the input
- labels have shape `[batch_size, label_dimension]`, the loss is the weighted
- sum over both `batch_size` and `label_dimension`.
-
- The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
- In many applications, the shape is `[batch_size, label_dimension]`.
-
- The `labels` shape must match `logits`, namely
- `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
- `[D0, D1, ... DN]` is also supported.
+ """Creates a `_Head` for regression using the mean squared loss.
- If `weight_column` is specified, weights must be of shape
- `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
- `[D0, D1, ... DN, label_dimension]`.
+ Uses `mean_squared_error` loss.
Args:
weight_column: A string or a `_NumericColumn` created by
@@ -196,29 +156,15 @@ def multi_label_head(n_classes,
or more associated labels, from a discrete set. This is distinct from
`multi_class_head` which has exactly one label per example.
- Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over
- the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
- the loss is the average over `n_classes` and the weighted sum over
- `batch_size`.
-
- The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
- applications, the shape is `[batch_size, label_n_classes]`.
-
- Labels can be:
- * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
- * An integer `SparseTensor` of class indices. The `dense_shape` must be
- `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
- * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
- must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
-
- If `weight_column` is specified, weights must be of shape
- `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
+ Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
+ multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
+ `SparseTensor` of class indices.
Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
`(labels, logits, features)` as arguments and returns unreduced loss with
- shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
- shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
- `label_vocabulary` to the input labels before passing them to `loss_fn`.
+ shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape
+ `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the
+ input labels before passing them to `loss_fn`.
Args:
n_classes: Number of classes, must be greater than 1 (for 1 class, use
@@ -245,7 +191,7 @@ def multi_label_head(n_classes,
An instance of `_Head` for multi-label classification.
Raises:
- ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid.
+ ValueError: if `n_classes` or `thresholds` is invalid.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if n_classes is None or n_classes < 2:
@@ -313,36 +259,26 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
indices=labels.indices,
values=label_ids_values,
dense_shape=labels.dense_shape)
- return math_ops.to_int64(
- sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
else:
- err_msg = (
- r'labels must be an integer SparseTensor with values in '
- r'[0, {})'.format(self._n_classes))
- assert_int = check_ops.assert_integer(
- labels.values, message=err_msg)
- assert_less = check_ops.assert_less(
- labels.values,
- ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
- message=err_msg)
- assert_greater = check_ops.assert_non_negative(
- labels.values, message=err_msg)
- with ops.control_dependencies(
- [assert_int, assert_less, assert_greater]):
- return math_ops.to_int64(
- sparse_ops.sparse_to_indicator(labels, self._n_classes))
- err_msg = (
- r'labels must be an integer indicator Tensor with values in [0, 1]')
- return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access,
+ label_ids = labels
+ return math_ops.to_int64(
+ sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
+ msg = ('labels shape must be [batch_size, {}]. '
+ 'Given: ').format(self._n_classes)
+ labels_shape = array_ops.shape(labels)
+ check_rank_op = control_flow_ops.Assert(
+ math_ops.equal(array_ops.rank(labels), 2),
+ data=[msg, labels_shape])
+ check_label_dim = control_flow_ops.Assert(
+ math_ops.equal(labels_shape[-1], self._n_classes),
+ data=[msg, labels_shape])
+ with ops.control_dependencies([check_rank_op, check_label_dim]):
+ return array_ops.identity(labels)
def create_loss(self, features, mode, logits, labels):
"""See `Head`."""
del mode # Unused for this head.
- logits = ops.convert_to_tensor(logits)
processed_labels = self._process_labels(labels)
- processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access
- labels=processed_labels, logits=logits,
- expected_labels_dimension=self.logits_dimension)
if self._loss_fn:
unweighted_loss = _call_loss_fn(
loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
@@ -354,8 +290,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
# Averages loss over classes.
unweighted_loss = math_ops.reduce_mean(
unweighted_loss, axis=-1, keep_dims=True)
- weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access,
- features=features, weight_column=self._weight_column, logits=logits)
+ weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access,
weighted_sum_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
# _weights() can return 1.
@@ -370,7 +305,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
self, features, mode, logits, labels=None, train_op_fn=None):
"""See `Head`."""
with ops.name_scope(self._name, 'head'):
- logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access
+ logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access
# Predict.
pred_keys = prediction_keys.PredictionKeys
@@ -400,8 +335,6 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
# Eval.
if mode == model_fn.ModeKeys.EVAL:
- weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access,
- features=features, weight_column=self._weight_column, logits=logits)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
@@ -409,7 +342,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
eval_metric_ops=self._eval_metric_ops(
labels=processed_labels,
probabilities=probabilities,
- weights=weights,
+ weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access,
weighted_sum_loss=weighted_sum_loss,
example_weight_sum=example_weight_sum))