diff options
author | 2016-11-22 16:39:19 -0800 | |
---|---|---|
committer | 2016-11-22 16:43:49 -0800 | |
commit | 8eff2d62f232abece68949af95efddc91689206f (patch) | |
tree | a29d3e5ee6e40882c7c7129b5d6a53cdd862220d | |
parent | c348cead5bcd07f62d48b834cb2b57866bbe7635 (diff) |
Split _BinaryLogisticHead from _MultiClassHead.
Change: 139971064
3 files changed, 285 insertions, 103 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index f7bb62e89e..30a822fdde 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -102,16 +102,18 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None, Raises: ValueError: if n_classes is < 2 """ - if n_classes < 2: - raise ValueError("n_classes must be > 1 for classification.") + if (n_classes is None) or (n_classes < 2): + raise ValueError( + "n_classes must be > 1 for classification: %s." % n_classes) if n_classes == 2: - loss_fn = _log_loss_with_two_classes - else: - loss_fn = _softmax_cross_entropy_loss - return _MultiClassHead(train_loss_fn=loss_fn, - eval_loss_fn=loss_fn, - n_classes=n_classes, + return _BinaryLogisticHead(label_name=label_name, + weight_column_name=weight_column_name, + enable_centered_bias=enable_centered_bias, + head_name=head_name, + thresholds=thresholds) + + return _MultiClassHead(n_classes=n_classes, label_name=label_name, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, @@ -268,7 +270,6 @@ class _RegressionHead(_Head): """See `_Head`.""" _check_mode_valid(mode) _check_logits_input_not_supported(logits, logits_input) - predictions = self._predictions(logits) if (mode == model_fn.ModeKeys.INFER) or (labels is None): loss = None train_op = None @@ -278,15 +279,14 @@ class _RegressionHead(_Head): train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL else self._train_op(features, labels, train_op_fn, logits)) eval_metric_ops = self._eval_metric_ops(features, labels, logits) - signature_fn = self._signature_fn() return model_fn.ModelFnOps( mode=mode, - predictions=predictions, + predictions=self._predictions(logits), loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=signature_fn) + signature_fn=self._signature_fn()) def _training_loss(self, features, labels, logits, name="training_loss"): """Returns training loss tensor for this head. @@ -403,18 +403,258 @@ class _RegressionHead(_Head): self._weight_column_name)} +def _log_loss_with_two_classes(logits, labels): + # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. + if len(labels.get_shape()) == 1: + labels = array_ops.expand_dims(labels, dim=[1]) + return nn.sigmoid_cross_entropy_with_logits( + logits, math_ops.to_float(labels)) + + +class _BinaryLogisticHead(_Head): + """_Head for binary logistic classifciation.""" + + def __init__(self, label_name, weight_column_name, enable_centered_bias, + head_name, loss_fn=_log_loss_with_two_classes, thresholds=None): + """Base type for all single heads. + + Args: + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + weight_column_name: A string defining feature column name representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + enable_centered_bias: A bool. If True, estimator will learn a centered + bias variable for each class. Rest of the model structure learns the + residual after centered bias. + head_name: name of the head. If provided, predictions, summary and metrics + keys will be prefixed by the head_name and an underscore. + loss_fn: Loss function. + thresholds: thresholds for eval. + + Raises: + ValueError: if n_classes is invalid. + """ + self._thresholds = thresholds if thresholds else [.5] + self._label_name = label_name + self._weight_column_name = weight_column_name + self._head_name = head_name + self._loss_fn = loss_fn + self._enable_centered_bias = enable_centered_bias + self._centered_bias_weight_collection = _head_prefixed(head_name, + "centered_bias") + + @property + def logits_dimension(self): + return 1 + + def head_ops(self, features, labels, mode, train_op_fn, logits=None, + logits_input=None): + """See `_Head`.""" + _check_mode_valid(mode) + _check_logits_input_not_supported(logits, logits_input) + if (mode == model_fn.ModeKeys.INFER) or (labels is None): + loss = None + train_op = None + eval_metric_ops = None + else: + loss = self._training_loss(features, labels, logits) + train_op = (None if train_op_fn is None + else self._train_op(features, labels, train_op_fn, logits)) + eval_metric_ops = self._eval_metric_ops(features, labels, logits) + + return model_fn.ModelFnOps( + mode=mode, + predictions=self._predictions(logits), + loss=loss, + train_op=train_op, + eval_metric_ops=eval_metric_ops, + signature_fn=self._signature_fn()) + + def _training_loss(self, features, labels, logits=None, name="training_loss"): + """Returns training loss tensor for this head. + + Training loss is different from the loss reported on the tensorboard as we + should respect the example weights when computing the gradient. + + L = sum_{i} w_{i} * l_{i} / B + + where B is the number of examples in the batch, l_{i}, w_{i} are individual + losses, and example weight. + + Args: + features: features dict. + labels: either a tensor for labels or in multihead case, a dict of string + to labels tensor. + logits: logits, a float tensor. + name: Op name. + + Returns: + A loss `Output`. + """ + labels = _check_labels(labels, self._label_name) + + if self._enable_centered_bias: + logits = nn.bias_add(logits, _centered_bias( + self.logits_dimension, + self._centered_bias_weight_collection)) + + loss_unweighted = self._loss_fn(logits, labels) + loss, weighted_average_loss = _loss( + loss_unweighted, + _weight_tensor(features, self._weight_column_name), + name=name) + summary.scalar( + _head_prefixed(self._head_name, "loss"), weighted_average_loss) + return loss + + def _train_op(self, features, labels, train_op_fn, logits): + """Returns op for the training step.""" + loss = self._training_loss(features, labels, logits) + train_op = train_op_fn(loss) + + if self._enable_centered_bias: + centered_bias_step = [_centered_bias_step( + self.logits_dimension, + self._centered_bias_weight_collection, + labels, + self._loss_fn)] + train_op = control_flow_ops.group(train_op, *centered_bias_step) + + return train_op + + def _eval_metric_ops(self, features, labels, logits): + """Returns a dict of metric ops keyed by name.""" + labels = _check_labels(labels, self._label_name) + predictions = self._predictions(logits) + return estimator._make_metrics_ops( # pylint: disable=protected-access + self._default_metrics(), features, labels, predictions) + + def _predictions(self, logits): + """Returns a dict of predictions. + + Args: + logits: logits `Output` before applying possible centered bias. + + Returns: + Dict of prediction `Output` keyed by `PredictionKey`. + """ + if self._enable_centered_bias: + logits = nn.bias_add(logits, _centered_bias( + self.logits_dimension, + self._centered_bias_weight_collection)) + return self._logits_to_predictions(logits) + + def _logits_to_predictions(self, logits): + """Returns a dict of predictions. + + Args: + logits: logits `Output` after applying possible centered bias. + + Returns: + Dict of prediction `Output` keyed by `PredictionKey`. + """ + predictions = {prediction_key.PredictionKey.LOGITS: logits} + predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( + logits) + logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) + predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax( + logits) + predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax( + logits, 1) + return predictions + + def _signature_fn(self): + """Returns the signature_fn to be used in exporting.""" + def _classification_signature_fn(examples, unused_features, predictions): + """Servo signature function.""" + if isinstance(predictions, dict): + default_signature = exporter.classification_signature( + input_tensor=examples, + classes_tensor=predictions[prediction_key.PredictionKey.CLASSES], + scores_tensor=predictions[ + prediction_key.PredictionKey.PROBABILITIES]) + else: + default_signature = exporter.classification_signature( + input_tensor=examples, + scores_tensor=predictions) + + # TODO(zakaria): add validation + return default_signature, {} + return _classification_signature_fn + + def _default_metrics(self): + """Returns a dict of `MetricSpec` objects keyed by name.""" + metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): + _weighted_average_loss_metric_spec( + self._loss_fn, + prediction_key.PredictionKey.LOGITS, + self._label_name, + self._weight_column_name)} + + # TODO(b/29366811): This currently results in both an "accuracy" and an + # "accuracy/threshold_0.500000_mean" metric for binary classification. + metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = ( + metric_spec.MetricSpec(metrics_lib.streaming_accuracy, + prediction_key.PredictionKey.CLASSES, + self._label_name, + self._weight_column_name)) + def _add_binary_metric(key, metric_fn): + metrics[_head_prefixed(self._head_name, key)] = ( + metric_spec.MetricSpec(metric_fn, + prediction_key.PredictionKey.LOGISTIC, + self._label_name, + self._weight_column_name)) + _add_binary_metric( + metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean) + _add_binary_metric( + metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean) + + # Also include the streaming mean of the label as an accuracy baseline, as + # a reminder to users. + _add_binary_metric( + metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean) + + _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc) + + for threshold in self._thresholds: + _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold, + _accuracy_at_threshold(threshold)) + # Precision for positive examples. + _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold, + _streaming_at_threshold( + metrics_lib.streaming_precision_at_thresholds, + threshold),) + # Recall for positive examples. + _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold, + _streaming_at_threshold( + metrics_lib.streaming_recall_at_thresholds, + threshold)) + return metrics + + +def _softmax_cross_entropy_loss(logits, labels): + # Check that we got integer for classification. + if not labels.dtype.is_integer: + raise ValueError("Labels dtype should be integer " + "Instead got %s." % labels.dtype) + # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. + if len(labels.get_shape()) == 2: + labels = array_ops.squeeze(labels, squeeze_dims=[1]) + return nn.sparse_softmax_cross_entropy_with_logits(logits, labels) + + class _MultiClassHead(_Head): """_Head for classification.""" - def __init__(self, train_loss_fn, eval_loss_fn, n_classes, label_name, + def __init__(self, n_classes, label_name, weight_column_name, enable_centered_bias, head_name, - thresholds=None): + loss_fn=_softmax_cross_entropy_loss, thresholds=None): """Base type for all single heads. Args: - train_loss_fn: loss_fn for training. - eval_loss_fn: loss_fn for eval. - n_classes: number of classes. + n_classes: Number of classes, must be greater than 2 (for 2 classes, use + `_BinaryLogisticHead`). label_name: String, name of the key in label dict. Can be null if label is a tensor (single headed models). weight_column_name: A string defining feature column name representing @@ -425,21 +665,21 @@ class _MultiClassHead(_Head): residual after centered bias. head_name: name of the head. If provided, predictions, summary and metrics keys will be prefixed by the head_name and an underscore. + loss_fn: Loss function. thresholds: thresholds for eval. Raises: ValueError: if n_classes is invalid. """ - if n_classes < 2: - raise ValueError("n_classes must be >= 2") + if (n_classes is None) or (n_classes <= 2): + raise ValueError("n_classes must be > 2: %s." % n_classes) self._thresholds = thresholds if thresholds else [.5] - self._train_loss_fn = train_loss_fn - self._eval_loss_fn = eval_loss_fn - self._logits_dimension = 1 if n_classes == 2 else n_classes + self._logits_dimension = n_classes self._label_name = label_name self._weight_column_name = weight_column_name self._head_name = head_name + self._loss_fn = loss_fn self._enable_centered_bias = enable_centered_bias self._centered_bias_weight_collection = _head_prefixed(head_name, "centered_bias") @@ -453,7 +693,6 @@ class _MultiClassHead(_Head): """See `_Head`.""" _check_mode_valid(mode) _check_logits_input_not_supported(logits, logits_input) - predictions = self._predictions(logits) if (mode == model_fn.ModeKeys.INFER) or (labels is None): loss = None train_op = None @@ -463,15 +702,14 @@ class _MultiClassHead(_Head): train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL else self._train_op(features, labels, train_op_fn, logits)) eval_metric_ops = self._eval_metric_ops(features, labels, logits) - signature_fn = self._signature_fn() return model_fn.ModelFnOps( mode=mode, - predictions=predictions, + predictions=self._predictions(logits), loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=signature_fn) + signature_fn=self._signature_fn()) def _training_loss(self, features, labels, logits=None, name="training_loss"): """Returns training loss tensor for this head. @@ -501,7 +739,7 @@ class _MultiClassHead(_Head): self.logits_dimension, self._centered_bias_weight_collection)) - loss_unweighted = self._train_loss_fn(logits, labels) + loss_unweighted = self._loss_fn(logits, labels) loss, weighted_average_loss = _loss( loss_unweighted, _weight_tensor(features, self._weight_column_name), @@ -520,7 +758,7 @@ class _MultiClassHead(_Head): self.logits_dimension, self._centered_bias_weight_collection, labels, - self._train_loss_fn)] + self._loss_fn)] train_op = control_flow_ops.group(train_op, *centered_bias_step) return train_op @@ -557,10 +795,6 @@ class _MultiClassHead(_Head): Dict of prediction `Tensor` keyed by `PredictionKey`. """ predictions = {prediction_key.PredictionKey.LOGITS: logits} - if self.logits_dimension == 1: - predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid( - logits) - logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax( logits) predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax( @@ -591,7 +825,7 @@ class _MultiClassHead(_Head): """Returns a dict of `MetricSpec` objects keyed by name.""" metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): _weighted_average_loss_metric_spec( - self._eval_loss_fn, + self._loss_fn, prediction_key.PredictionKey.LOGITS, self._label_name, self._weight_column_name)} @@ -603,38 +837,9 @@ class _MultiClassHead(_Head): prediction_key.PredictionKey.CLASSES, self._label_name, self._weight_column_name)) - if self.logits_dimension == 1: - def _add_binary_metric(key, metric_fn): - metrics[_head_prefixed(self._head_name, key)] = ( - metric_spec.MetricSpec(metric_fn, - prediction_key.PredictionKey.LOGISTIC, - self._label_name, - self._weight_column_name)) - _add_binary_metric( - metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean) - _add_binary_metric( - metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean) - - # Also include the streaming mean of the label as an accuracy baseline, as - # a reminder to users. - _add_binary_metric( - metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean) - - _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc) - - for threshold in self._thresholds: - _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold, - _accuracy_at_threshold(threshold)) - # Precision for positive examples. - _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold, - _streaming_at_threshold( - metrics_lib.streaming_precision_at_thresholds, - threshold),) - # Recall for positive examples. - _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold, - _streaming_at_threshold( - metrics_lib.streaming_recall_at_thresholds, - threshold)) + + # TODO(b/32953199): Add multiclass metrics. + return metrics @@ -645,12 +850,12 @@ def _check_labels(labels, label_name): return labels -class _BinarySvmHead(_MultiClassHead): +class _BinarySvmHead(_BinaryLogisticHead): """_Head for binary classification using SVMs.""" def __init__(self, label_name, weight_column_name, enable_centered_bias, head_name, thresholds): - def loss_fn(logits, labels): + def _loss_fn(logits, labels): check_shape_op = control_flow_ops.Assert( math_ops.less_equal(array_ops.rank(labels), 2), ["labels shape should be either [batch_size, 1] or [batch_size]"]) @@ -660,9 +865,7 @@ class _BinarySvmHead(_MultiClassHead): return losses.hinge_loss(logits, labels) super(_BinarySvmHead, self).__init__( - train_loss_fn=loss_fn, - eval_loss_fn=loss_fn, - n_classes=2, + loss_fn=_loss_fn, label_name=label_name, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, @@ -683,7 +886,7 @@ class _BinarySvmHead(_MultiClassHead): """See `_MultiClassHead`.""" metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS): _weighted_average_loss_metric_spec( - self._eval_loss_fn, + self._loss_fn, prediction_key.PredictionKey.LOGITS, self._label_name, self._weight_column_name)} @@ -821,27 +1024,6 @@ def _mean_squared_loss(logits, labels): return math_ops.square(logits - math_ops.to_float(labels)) -def _log_loss_with_two_classes(logits, labels): - # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. - if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=[1]) - loss_vec = nn.sigmoid_cross_entropy_with_logits(logits, - math_ops.to_float(labels)) - return loss_vec - - -def _softmax_cross_entropy_loss(logits, labels): - # Check that we got integer for classification. - if not labels.dtype.is_integer: - raise ValueError("Labels dtype should be integer " - "Instead got %s." % labels.dtype) - # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. - if len(labels.get_shape()) == 2: - labels = array_ops.squeeze(labels, squeeze_dims=[1]) - loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, labels) - return loss_vec - - def _sigmoid_cross_entropy_loss(logits, labels): # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. return nn.sigmoid_cross_entropy_with_logits(logits, math_ops.to_float(labels)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 26cf8d372e..5786376314 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase): _noop_train_op, logits=logits) self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss)) - def testMultiClassWithInvalidNClass(self): - try: - head_lib._multi_class_head(n_classes=1) - self.fail("Softmax with no n_classes did not raise error.") - except ValueError: - # Expected - pass + def testInvalidNClasses(self): + for n_classes in (None, -1, 0, 1): + with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"): + head_lib._multi_class_head(n_classes=n_classes) class BinarySvmModelHeadTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index a1175c327d..763d70ddaf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params): if not isinstance(optimizer, sdca_optimizer.SDCAOptimizer): raise ValueError("Optimizer must be of type SDCAOptimizer") - if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access + # pylint: disable=protected-access + if isinstance(head, head_lib._BinarySvmHead): loss_type = "hinge_loss" - elif isinstance(head, head_lib._MultiClassHead): # pylint: disable=protected-access + elif isinstance( + head, (head_lib._MultiClassHead, head_lib._BinaryLogisticHead)): loss_type = "logistic_loss" - elif isinstance(head, head_lib._RegressionHead): # pylint: disable=protected-access + elif isinstance(head, head_lib._RegressionHead): loss_type = "squared_loss" else: - return ValueError("Unsupported head type: {}".format(head)) + raise ValueError("Unsupported head type: {}".format(head)) + # pylint: enable=protected-access parent_scope = "linear" |