author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-22 16:39:19 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-22 16:43:49 -0800
commit 8eff2d62f232abece68949af95efddc91689206f (patch)
tree   a29d3e5ee6e40882c7c7129b5d6a53cdd862220d
parent c348cead5bcd07f62d48b834cb2b57866bbe7635 (diff)
Split _BinaryLogisticHead from _MultiClassHead.
Change: 139971064
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py       | 366
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head_test.py  |  11
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/linear.py     |  11
3 files changed, 285 insertions, 103 deletions
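
In short: the _multi_class_head factory now validates n_classes eagerly and routes binary problems to the new _BinaryLogisticHead, reserving _MultiClassHead for three or more classes. A condensed sketch of the dispatch, reconstructed from the head.py hunk below (keyword defaults other than label_name and weight_column_name are not visible in the hunk and are assumed here):

    def _multi_class_head(n_classes, label_name=None, weight_column_name=None,
                          enable_centered_bias=False, head_name=None,
                          thresholds=None):
        # None or anything below 2 cannot describe a classification problem.
        if (n_classes is None) or (n_classes < 2):
            raise ValueError(
                "n_classes must be > 1 for classification: %s." % n_classes)
        if n_classes == 2:
            # Binary case: a single logit trained with sigmoid cross-entropy.
            return _BinaryLogisticHead(label_name=label_name,
                                       weight_column_name=weight_column_name,
                                       enable_centered_bias=enable_centered_bias,
                                       head_name=head_name,
                                       thresholds=thresholds)
        # Three or more classes: sparse softmax cross-entropy over n_classes logits.
        return _MultiClassHead(n_classes=n_classes,
                               label_name=label_name,
                               weight_column_name=weight_column_name,
                               enable_centered_bias=enable_centered_bias,
                               head_name=head_name,
                               thresholds=thresholds)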
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index f7bb62e89e..30a822fdde 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -102,16 +102,18 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None,
Raises:
ValueError: if n_classes is None or < 2.
"""
- if n_classes < 2:
- raise ValueError("n_classes must be > 1 for classification.")
+ if (n_classes is None) or (n_classes < 2):
+ raise ValueError(
+ "n_classes must be > 1 for classification: %s." % n_classes)
if n_classes == 2:
- loss_fn = _log_loss_with_two_classes
- else:
- loss_fn = _softmax_cross_entropy_loss
- return _MultiClassHead(train_loss_fn=loss_fn,
- eval_loss_fn=loss_fn,
- n_classes=n_classes,
+ return _BinaryLogisticHead(label_name=label_name,
+ weight_column_name=weight_column_name,
+ enable_centered_bias=enable_centered_bias,
+ head_name=head_name,
+ thresholds=thresholds)
+
+ return _MultiClassHead(n_classes=n_classes,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
@@ -268,7 +270,6 @@ class _RegressionHead(_Head):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
- predictions = self._predictions(logits)
if (mode == model_fn.ModeKeys.INFER) or (labels is None):
loss = None
train_op = None
@@ -278,15 +279,14 @@ class _RegressionHead(_Head):
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
else self._train_op(features, labels, train_op_fn, logits))
eval_metric_ops = self._eval_metric_ops(features, labels, logits)
- signature_fn = self._signature_fn()
return model_fn.ModelFnOps(
mode=mode,
- predictions=predictions,
+ predictions=self._predictions(logits),
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=signature_fn)
+ signature_fn=self._signature_fn())
def _training_loss(self, features, labels, logits, name="training_loss"):
"""Returns training loss tensor for this head.
@@ -403,18 +403,258 @@ class _RegressionHead(_Head):
self._weight_column_name)}
+def _log_loss_with_two_classes(logits, labels):
+ # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
+ if len(labels.get_shape()) == 1:
+ labels = array_ops.expand_dims(labels, dim=[1])
+ return nn.sigmoid_cross_entropy_with_logits(
+ logits, math_ops.to_float(labels))
+
+
+class _BinaryLogisticHead(_Head):
+ """_Head for binary logistic classifciation."""
+
+ def __init__(self, label_name, weight_column_name, enable_centered_bias,
+ head_name, loss_fn=_log_loss_with_two_classes, thresholds=None):
+ """Base type for all single heads.
+
+ Args:
+      label_name: String, name of the key in label dict. Can be None if label
+        is a tensor (single-headed models).
+ weight_column_name: A string defining feature column name representing
+        weights. It is used to down-weight or boost examples during training; it
+        will be multiplied by the loss of the example.
+ enable_centered_bias: A bool. If True, estimator will learn a centered
+ bias variable for each class. Rest of the model structure learns the
+ residual after centered bias.
+ head_name: name of the head. If provided, predictions, summary and metrics
+ keys will be prefixed by the head_name and an underscore.
+ loss_fn: Loss function.
+ thresholds: thresholds for eval.
+ """
+ self._thresholds = thresholds if thresholds else [.5]
+ self._label_name = label_name
+ self._weight_column_name = weight_column_name
+ self._head_name = head_name
+ self._loss_fn = loss_fn
+ self._enable_centered_bias = enable_centered_bias
+ self._centered_bias_weight_collection = _head_prefixed(head_name,
+ "centered_bias")
+
+ @property
+ def logits_dimension(self):
+ return 1
+
+ def head_ops(self, features, labels, mode, train_op_fn, logits=None,
+ logits_input=None):
+ """See `_Head`."""
+ _check_mode_valid(mode)
+ _check_logits_input_not_supported(logits, logits_input)
+ if (mode == model_fn.ModeKeys.INFER) or (labels is None):
+ loss = None
+ train_op = None
+ eval_metric_ops = None
+ else:
+ loss = self._training_loss(features, labels, logits)
+ train_op = (None if train_op_fn is None
+ else self._train_op(features, labels, train_op_fn, logits))
+ eval_metric_ops = self._eval_metric_ops(features, labels, logits)
+
+ return model_fn.ModelFnOps(
+ mode=mode,
+ predictions=self._predictions(logits),
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ signature_fn=self._signature_fn())
+
+ def _training_loss(self, features, labels, logits=None, name="training_loss"):
+ """Returns training loss tensor for this head.
+
+    Training loss differs from the loss reported in TensorBoard because we
+    must respect the example weights when computing the gradient.
+
+ L = sum_{i} w_{i} * l_{i} / B
+
+    where B is the number of examples in the batch, and l_{i} and w_{i} are the
+    loss and weight of the i-th example.
+
+ Args:
+ features: features dict.
+      labels: either a label tensor or, in the multi-head case, a dict mapping
+        string keys to label tensors.
+ logits: logits, a float tensor.
+ name: Op name.
+
+ Returns:
+ A loss `Output`.
+ """
+ labels = _check_labels(labels, self._label_name)
+
+ if self._enable_centered_bias:
+ logits = nn.bias_add(logits, _centered_bias(
+ self.logits_dimension,
+ self._centered_bias_weight_collection))
+
+ loss_unweighted = self._loss_fn(logits, labels)
+ loss, weighted_average_loss = _loss(
+ loss_unweighted,
+ _weight_tensor(features, self._weight_column_name),
+ name=name)
+ summary.scalar(
+ _head_prefixed(self._head_name, "loss"), weighted_average_loss)
+ return loss
+
+ def _train_op(self, features, labels, train_op_fn, logits):
+ """Returns op for the training step."""
+ loss = self._training_loss(features, labels, logits)
+ train_op = train_op_fn(loss)
+
+ if self._enable_centered_bias:
+ centered_bias_step = [_centered_bias_step(
+ self.logits_dimension,
+ self._centered_bias_weight_collection,
+ labels,
+ self._loss_fn)]
+ train_op = control_flow_ops.group(train_op, *centered_bias_step)
+
+ return train_op
+
+ def _eval_metric_ops(self, features, labels, logits):
+ """Returns a dict of metric ops keyed by name."""
+ labels = _check_labels(labels, self._label_name)
+ predictions = self._predictions(logits)
+ return estimator._make_metrics_ops( # pylint: disable=protected-access
+ self._default_metrics(), features, labels, predictions)
+
+ def _predictions(self, logits):
+ """Returns a dict of predictions.
+
+ Args:
+ logits: logits `Output` before applying possible centered bias.
+
+ Returns:
+ Dict of prediction `Output` keyed by `PredictionKey`.
+ """
+ if self._enable_centered_bias:
+ logits = nn.bias_add(logits, _centered_bias(
+ self.logits_dimension,
+ self._centered_bias_weight_collection))
+ return self._logits_to_predictions(logits)
+
+ def _logits_to_predictions(self, logits):
+ """Returns a dict of predictions.
+
+ Args:
+ logits: logits `Output` after applying possible centered bias.
+
+ Returns:
+ Dict of prediction `Output` keyed by `PredictionKey`.
+ """
+ predictions = {prediction_key.PredictionKey.LOGITS: logits}
+ predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
+ logits)
+ logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
+ predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
+ logits)
+ predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
+ logits, 1)
+ return predictions
+
+ def _signature_fn(self):
+ """Returns the signature_fn to be used in exporting."""
+ def _classification_signature_fn(examples, unused_features, predictions):
+ """Servo signature function."""
+ if isinstance(predictions, dict):
+ default_signature = exporter.classification_signature(
+ input_tensor=examples,
+ classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
+ scores_tensor=predictions[
+ prediction_key.PredictionKey.PROBABILITIES])
+ else:
+ default_signature = exporter.classification_signature(
+ input_tensor=examples,
+ scores_tensor=predictions)
+
+ # TODO(zakaria): add validation
+ return default_signature, {}
+ return _classification_signature_fn
+
+ def _default_metrics(self):
+ """Returns a dict of `MetricSpec` objects keyed by name."""
+ metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+ _weighted_average_loss_metric_spec(
+ self._loss_fn,
+ prediction_key.PredictionKey.LOGITS,
+ self._label_name,
+ self._weight_column_name)}
+
+ # TODO(b/29366811): This currently results in both an "accuracy" and an
+ # "accuracy/threshold_0.500000_mean" metric for binary classification.
+ metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
+ metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
+ prediction_key.PredictionKey.CLASSES,
+ self._label_name,
+ self._weight_column_name))
+
+    def _add_binary_metric(key, metric_fn):
+ metrics[_head_prefixed(self._head_name, key)] = (
+ metric_spec.MetricSpec(metric_fn,
+ prediction_key.PredictionKey.LOGISTIC,
+ self._label_name,
+ self._weight_column_name))
+ _add_binary_metric(
+ metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
+ _add_binary_metric(
+ metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
+
+ # Also include the streaming mean of the label as an accuracy baseline, as
+ # a reminder to users.
+ _add_binary_metric(
+ metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
+
+ _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
+
+ for threshold in self._thresholds:
+ _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
+ _accuracy_at_threshold(threshold))
+ # Precision for positive examples.
+ _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
+ _streaming_at_threshold(
+ metrics_lib.streaming_precision_at_thresholds,
+                                     threshold))
+ # Recall for positive examples.
+ _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
+ _streaming_at_threshold(
+ metrics_lib.streaming_recall_at_thresholds,
+ threshold))
+ return metrics
+
+
+def _softmax_cross_entropy_loss(logits, labels):
+  # Check that we got integer labels for classification.
+  if not labels.dtype.is_integer:
+    raise ValueError("Labels dtype should be integer. "
+                     "Instead got %s." % labels.dtype)
+ # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
+ if len(labels.get_shape()) == 2:
+ labels = array_ops.squeeze(labels, squeeze_dims=[1])
+ return nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
+
+
class _MultiClassHead(_Head):
"""_Head for classification."""
- def __init__(self, train_loss_fn, eval_loss_fn, n_classes, label_name,
+ def __init__(self, n_classes, label_name,
weight_column_name, enable_centered_bias, head_name,
- thresholds=None):
+ loss_fn=_softmax_cross_entropy_loss, thresholds=None):
"""Base type for all single heads.
Args:
- train_loss_fn: loss_fn for training.
- eval_loss_fn: loss_fn for eval.
- n_classes: number of classes.
+ n_classes: Number of classes, must be greater than 2 (for 2 classes, use
+ `_BinaryLogisticHead`).
label_name: String, name of the key in label dict. Can be None if label
is a tensor (single-headed models).
weight_column_name: A string defining feature column name representing
@@ -425,21 +665,21 @@ class _MultiClassHead(_Head):
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
+ loss_fn: Loss function.
thresholds: thresholds for eval.
Raises:
ValueError: if n_classes is invalid.
"""
- if n_classes < 2:
- raise ValueError("n_classes must be >= 2")
+ if (n_classes is None) or (n_classes <= 2):
+ raise ValueError("n_classes must be > 2: %s." % n_classes)
self._thresholds = thresholds if thresholds else [.5]
- self._train_loss_fn = train_loss_fn
- self._eval_loss_fn = eval_loss_fn
- self._logits_dimension = 1 if n_classes == 2 else n_classes
+ self._logits_dimension = n_classes
self._label_name = label_name
self._weight_column_name = weight_column_name
self._head_name = head_name
+ self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
self._centered_bias_weight_collection = _head_prefixed(head_name,
"centered_bias")
@@ -453,7 +693,6 @@ class _MultiClassHead(_Head):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
- predictions = self._predictions(logits)
if (mode == model_fn.ModeKeys.INFER) or (labels is None):
loss = None
train_op = None
@@ -463,15 +702,14 @@ class _MultiClassHead(_Head):
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
else self._train_op(features, labels, train_op_fn, logits))
eval_metric_ops = self._eval_metric_ops(features, labels, logits)
- signature_fn = self._signature_fn()
return model_fn.ModelFnOps(
mode=mode,
- predictions=predictions,
+ predictions=self._predictions(logits),
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=signature_fn)
+ signature_fn=self._signature_fn())
def _training_loss(self, features, labels, logits=None, name="training_loss"):
"""Returns training loss tensor for this head.
@@ -501,7 +739,7 @@ class _MultiClassHead(_Head):
self.logits_dimension,
self._centered_bias_weight_collection))
- loss_unweighted = self._train_loss_fn(logits, labels)
+ loss_unweighted = self._loss_fn(logits, labels)
loss, weighted_average_loss = _loss(
loss_unweighted,
_weight_tensor(features, self._weight_column_name),
@@ -520,7 +758,7 @@ class _MultiClassHead(_Head):
self.logits_dimension,
self._centered_bias_weight_collection,
labels,
- self._train_loss_fn)]
+ self._loss_fn)]
train_op = control_flow_ops.group(train_op, *centered_bias_step)
return train_op
@@ -557,10 +795,6 @@ class _MultiClassHead(_Head):
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
predictions = {prediction_key.PredictionKey.LOGITS: logits}
- if self.logits_dimension == 1:
- predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
- logits)
- logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
logits)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
@@ -591,7 +825,7 @@ class _MultiClassHead(_Head):
"""Returns a dict of `MetricSpec` objects keyed by name."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
- self._eval_loss_fn,
+ self._loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
@@ -603,38 +837,9 @@ class _MultiClassHead(_Head):
prediction_key.PredictionKey.CLASSES,
self._label_name,
self._weight_column_name))
- if self.logits_dimension == 1:
- def _add_binary_metric(key, metric_fn):
- metrics[_head_prefixed(self._head_name, key)] = (
- metric_spec.MetricSpec(metric_fn,
- prediction_key.PredictionKey.LOGISTIC,
- self._label_name,
- self._weight_column_name))
- _add_binary_metric(
- metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
- _add_binary_metric(
- metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
-
- # Also include the streaming mean of the label as an accuracy baseline, as
- # a reminder to users.
- _add_binary_metric(
- metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
-
- _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
-
- for threshold in self._thresholds:
- _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
- _accuracy_at_threshold(threshold))
- # Precision for positive examples.
- _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
- _streaming_at_threshold(
- metrics_lib.streaming_precision_at_thresholds,
- threshold),)
- # Recall for positive examples.
- _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
- _streaming_at_threshold(
- metrics_lib.streaming_recall_at_thresholds,
- threshold))
+
+ # TODO(b/32953199): Add multiclass metrics.
+
return metrics
@@ -645,12 +850,12 @@ def _check_labels(labels, label_name):
return labels
-class _BinarySvmHead(_MultiClassHead):
+class _BinarySvmHead(_BinaryLogisticHead):
"""_Head for binary classification using SVMs."""
def __init__(self, label_name, weight_column_name, enable_centered_bias,
head_name, thresholds):
- def loss_fn(logits, labels):
+ def _loss_fn(logits, labels):
check_shape_op = control_flow_ops.Assert(
math_ops.less_equal(array_ops.rank(labels), 2),
["labels shape should be either [batch_size, 1] or [batch_size]"])
@@ -660,9 +865,7 @@ class _BinarySvmHead(_MultiClassHead):
return losses.hinge_loss(logits, labels)
super(_BinarySvmHead, self).__init__(
- train_loss_fn=loss_fn,
- eval_loss_fn=loss_fn,
- n_classes=2,
+ loss_fn=_loss_fn,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
@@ -683,7 +886,7 @@ class _BinarySvmHead(_MultiClassHead):
"""See `_MultiClassHead`."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
- self._eval_loss_fn,
+ self._loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
@@ -821,27 +1024,6 @@ def _mean_squared_loss(logits, labels):
return math_ops.square(logits - math_ops.to_float(labels))
-def _log_loss_with_two_classes(logits, labels):
- # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
- if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, dim=[1])
- loss_vec = nn.sigmoid_cross_entropy_with_logits(logits,
- math_ops.to_float(labels))
- return loss_vec
-
-
-def _softmax_cross_entropy_loss(logits, labels):
- # Check that we got integer for classification.
- if not labels.dtype.is_integer:
- raise ValueError("Labels dtype should be integer "
- "Instead got %s." % labels.dtype)
- # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
- if len(labels.get_shape()) == 2:
- labels = array_ops.squeeze(labels, squeeze_dims=[1])
- loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
- return loss_vec
-
-
def _sigmoid_cross_entropy_loss(logits, labels):
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
return nn.sigmoid_cross_entropy_with_logits(logits, math_ops.to_float(labels))
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 26cf8d372e..5786376314 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=logits)
self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss))
- def testMultiClassWithInvalidNClass(self):
- try:
- head_lib._multi_class_head(n_classes=1)
- self.fail("Softmax with no n_classes did not raise error.")
- except ValueError:
- # Expected
- pass
+ def testInvalidNClasses(self):
+ for n_classes in (None, -1, 0, 1):
+ with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
+ head_lib._multi_class_head(n_classes=n_classes)
class BinarySvmModelHeadTest(tf.test.TestCase):
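
For readers skimming the _training_loss docstrings in head.py above: the gradient-facing loss L = sum_{i} w_{i} * l_{i} / B divides by the batch size B, whereas the value sent to the summary is a weighted average. A toy check of the difference (the exact semantics of _loss are not shown in this diff; dividing by the weight sum is an assumption based on the name weighted_average_loss):

    import numpy as np

    losses = np.array([0.2, 0.5, 1.0, 0.3])   # per-example losses l_i
    weights = np.array([1.0, 2.0, 1.0, 1.0])  # per-example weights w_i

    training_loss = (weights * losses).sum() / len(losses)   # sum_i w_i*l_i / B = 0.625
    summary_loss = (weights * losses).sum() / weights.sum()  # assumed weighted mean = 0.5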
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index a1175c327d..763d70ddaf 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params):
if not isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
raise ValueError("Optimizer must be of type SDCAOptimizer")
- if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ if isinstance(head, head_lib._BinarySvmHead):
loss_type = "hinge_loss"
- elif isinstance(head, head_lib._MultiClassHead): # pylint: disable=protected-access
+ elif isinstance(
+ head, (head_lib._MultiClassHead, head_lib._BinaryLogisticHead)):
loss_type = "logistic_loss"
- elif isinstance(head, head_lib._RegressionHead): # pylint: disable=protected-access
+ elif isinstance(head, head_lib._RegressionHead):
loss_type = "squared_loss"
else:
- return ValueError("Unsupported head type: {}".format(head))
+ raise ValueError("Unsupported head type: {}".format(head))
+ # pylint: enable=protected-access
parent_scope = "linear"
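
Taken together, callers of the factory are unaffected by the split; only the concrete head class changes, which is why sdca_model_fn above must now accept either class for "logistic_loss". A hypothetical smoke check against the contrib API of this era (module path as in the diff paths):

    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

    binary_head = head_lib._multi_class_head(n_classes=2)   # now a _BinaryLogisticHead
    multi_head = head_lib._multi_class_head(n_classes=10)   # still a _MultiClassHead
    print(binary_head.logits_dimension)  # 1  (single logit, sigmoid loss)
    print(multi_head.logits_dimension)   # 10 (one logit per class, softmax loss)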