diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/head.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head.py | 143 |
1 file changed, 38 insertions, 105 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index a9311a20f1..e344ee3c3e 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -49,20 +48,7 @@ def multi_class_head(n_classes, Uses `sparse_softmax_cross_entropy` loss. - The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. - In many applications, the shape is `[batch_size, n_classes]`. - - `labels` must be a dense `Tensor` with shape matching `logits`, namely - `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string - `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, - `labels` must be an integer `Tensor` with values specifying the class index. - - If `weight_column` is specified, weights must be of shape - `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. - - The loss is the weighted sum over the input dimensions. Namely, if the input - labels have shape `[batch_size, 1]`, the loss is the weighted sum over - `batch_size`. + This head expects to be fed integer labels specifying the class index. Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use @@ -71,11 +57,11 @@ def multi_class_head(n_classes, `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. - label_vocabulary: A list or tuple of strings representing possible label - values. 
If it is not given, that means labels are already encoded as an - integer within [0, n_classes). If given, labels must be of string type and - have any value in `label_vocabulary`. Note that errors will be raised if - `label_vocabulary` is not provided but labels are strings. + label_vocabulary: A list of strings represents possible label values. If it + is not given, that means labels are already encoded as integer within + [0, n_classes). If given, labels must be string type and have any value in + `label_vocabulary`. Also there will be errors if vocabulary is not + provided and labels are string. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -98,20 +84,7 @@ def binary_classification_head( This head uses `sigmoid_cross_entropy_with_logits` loss. - The head expects `logits` with shape `[D0, D1, ... DN, 1]`. - In many applications, the shape is `[batch_size, 1]`. - - `labels` must be a dense `Tensor` with shape matching `logits`, namely - `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string - `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, - `labels` must be float `Tensor` with values in the interval `[0, 1]`. - - If `weight_column` is specified, weights must be of shape - `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. - - The loss is the weighted sum over the input dimensions. Namely, if the input - labels have shape `[batch_size, 1]`, the loss is the weighted sum over - `batch_size`. + This head expects to be fed float labels of shape `(batch_size, 1)`. Args: weight_column: A string or a `_NumericColumn` created by @@ -123,11 +96,11 @@ def binary_classification_head( generated for each threshold value. This threshold is applied to the logistic values to determine the binary classification (i.e., above the threshold is `true`, below is `false`. 
- label_vocabulary: A list or tuple of strings representing possible label - values. If it is not given, labels must be float with values within - [0, 1]. If given, labels must be string type and have any value in - `label_vocabulary`. Note that errors will be raised if `label_vocabulary` - is not provided but labels are strings. + label_vocabulary: A list of strings represents possible label values. If it + is not given, that means labels are already encoded within [0, 1]. If + given, labels must be string type and have any value in + `label_vocabulary`. Also there will be errors if vocabulary is not + provided and labels are string. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -147,22 +120,9 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the `mean_squared_error` loss. - - The loss is the weighted sum over all input dimensions. Namely, if the input - labels have shape `[batch_size, label_dimension]`, the loss is the weighted - sum over both `batch_size` and `label_dimension`. - - The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. - In many applications, the shape is `[batch_size, label_dimension]`. - - The `labels` shape must match `logits`, namely - `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape - `[D0, D1, ... DN]` is also supported. + """Creates a `_Head` for regression using the mean squared loss. - If `weight_column` is specified, weights must be of shape - `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or - `[D0, D1, ... DN, label_dimension]`. + Uses `mean_squared_error` loss. Args: weight_column: A string or a `_NumericColumn` created by @@ -196,29 +156,15 @@ def multi_label_head(n_classes, or more associated labels, from a discrete set. This is distinct from `multi_class_head` which has exactly one label per example. 
- Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over - the batch. Namely, if the input logits have shape `[batch_size, n_classes]`, - the loss is the average over `n_classes` and the weighted sum over - `batch_size`. - - The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many - applications, the shape is `[batch_size, label_n_classes]`. - - Labels can be: - * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` - * An integer `SparseTensor` of class indices. The `dense_shape` must be - `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. - * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` - must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. - - If `weight_column` is specified, weights must be of shape - `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a + multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer + `SparseTensor` of class indices. Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with - shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with - shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies - `label_vocabulary` to the input labels before passing them to `loss_fn`. + shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape + `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the + input labels before passing them to `loss_fn`. Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use @@ -245,7 +191,7 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. + ValueError: if `n_classes` or `thresholds` is invalid. 
""" thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -313,36 +259,26 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access indices=labels.indices, values=label_ids_values, dense_shape=labels.dense_shape) - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) else: - err_msg = ( - r'labels must be an integer SparseTensor with values in ' - r'[0, {})'.format(self._n_classes)) - assert_int = check_ops.assert_integer( - labels.values, message=err_msg) - assert_less = check_ops.assert_less( - labels.values, - ops.convert_to_tensor(self._n_classes, dtype=labels.dtype), - message=err_msg) - assert_greater = check_ops.assert_non_negative( - labels.values, message=err_msg) - with ops.control_dependencies( - [assert_int, assert_less, assert_greater]): - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(labels, self._n_classes)) - err_msg = ( - r'labels must be an integer indicator Tensor with values in [0, 1]') - return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access, + label_ids = labels + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) + msg = ('labels shape must be [batch_size, {}]. ' + 'Given: ').format(self._n_classes) + labels_shape = array_ops.shape(labels) + check_rank_op = control_flow_ops.Assert( + math_ops.equal(array_ops.rank(labels), 2), + data=[msg, labels_shape]) + check_label_dim = control_flow_ops.Assert( + math_ops.equal(labels_shape[-1], self._n_classes), + data=[msg, labels_shape]) + with ops.control_dependencies([check_rank_op, check_label_dim]): + return array_ops.identity(labels) def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode # Unused for this head. 
- logits = ops.convert_to_tensor(logits) processed_labels = self._process_labels(labels) - processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access - labels=processed_labels, logits=logits, - expected_labels_dimension=self.logits_dimension) if self._loss_fn: unweighted_loss = _call_loss_fn( loss_fn=self._loss_fn, labels=processed_labels, logits=logits, @@ -354,8 +290,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Averages loss over classes. unweighted_loss = math_ops.reduce_mean( unweighted_loss, axis=-1, keep_dims=True) - weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, - features=features, weight_column=self._weight_column, logits=logits) + weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access, weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -370,7 +305,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" with ops.name_scope(self._name, 'head'): - logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access + logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. pred_keys = prediction_keys.PredictionKeys @@ -400,8 +335,6 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. 
if mode == model_fn.ModeKeys.EVAL: - weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, - features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, @@ -409,7 +342,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, - weights=weights, + weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access, weighted_sum_loss=weighted_sum_loss, example_weight_sum=example_weight_sum)) |