diff options
author | 2017-03-06 19:47:17 -0800 | |
---|---|---|
committer | 2017-03-06 20:08:35 -0800 | |
commit | 3d725349272ca0a5f443ec631374a24474e5a513 (patch) | |
tree | 6bf629aa4de920d88e9fc4c349282c6a95aca351 | |
parent | 64809054c9083dbb66d632f47223d6c652d78d13 (diff) |
Allows users to provide custom loss function for multiclass and multilabel problems.
Change: 149376544
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 166 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head_test.py | 87 |
3 files changed, 196 insertions, 58 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 959b808d06..83904a77aa 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -602,6 +602,7 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:variables", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index b00c58b6e0..952cdeb5ec 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import abc +import inspect + import six from tensorflow.contrib import framework as framework_lib @@ -127,7 +129,8 @@ def _multi_class_head(n_classes, enable_centered_bias=False, head_name=None, thresholds=None, - metric_class_ids=None): + metric_class_ids=None, + loss_fn=None): """Creates a _Head for multi class single label classification. The Head uses softmax cross entropy loss. @@ -149,18 +152,25 @@ def _multi_class_head(n_classes, metric_class_ids: List of class IDs for which we should report per-class metrics. Must all be in the range `[0, n_classes)`. Invalid if `n_classes` is 2. + loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as + parameter and returns a weighted scalar loss. `weights` should be + optional. See `tf.losses` Returns: An instance of _MultiClassHead. Raises: - ValueError: if `n_classes` is < 2, or `metric_class_ids` is provided when - `n_classes` is 2. + ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when + `n_classes` is 2. + ValueError: If loss_fn does not have expected signature. """ if (n_classes is None) or (n_classes < 2): raise ValueError("n_classes must be > 1 for classification: %s." 
% n_classes) + if loss_fn: + _verify_loss_fn_args(loss_fn) + loss_fn = _wrap_custom_loss_fn(loss_fn) if loss_fn else None if n_classes == 2: if metric_class_ids: raise ValueError("metric_class_ids invalid for n_classes==2.") @@ -169,7 +179,8 @@ def _multi_class_head(n_classes, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias, head_name=head_name, - thresholds=thresholds) + thresholds=thresholds, + loss_fn=loss_fn) return _MultiClassHead( n_classes=n_classes, @@ -178,7 +189,8 @@ def _multi_class_head(n_classes, enable_centered_bias=enable_centered_bias, head_name=head_name, thresholds=thresholds, - metric_class_ids=metric_class_ids) + metric_class_ids=metric_class_ids, + loss_fn=loss_fn) def _binary_svm_head( @@ -223,7 +235,8 @@ def _multi_label_head(n_classes, enable_centered_bias=False, head_name=None, thresholds=None, - metric_class_ids=None): + metric_class_ids=None, + loss_fn=None): """Creates a _Head for multi label classification. The Head uses sigmoid cross entropy loss. @@ -244,15 +257,22 @@ def _multi_label_head(n_classes, thresholds: thresholds for eval metrics, defaults to [.5] metric_class_ids: List of class IDs for which we should report per-class metrics. Must all be in the range `[0, n_classes)`. + loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as + parameter and returns a weighted scalar loss. `weights` should be + optional. See `tf.losses` Returns: An instance of _MultiLabelHead. Raises: - ValueError: if n_classes is < 2 + ValueError: If n_classes is < 2 + ValueError: If loss_fn does not have expected signature. 
""" if n_classes < 2: raise ValueError("n_classes must be > 1 for classification.") + if loss_fn: + _verify_loss_fn_args(loss_fn) + return _MultiLabelHead( n_classes=n_classes, label_name=label_name, @@ -260,7 +280,8 @@ def _multi_label_head(n_classes, enable_centered_bias=enable_centered_bias, head_name=head_name, thresholds=thresholds, - metric_class_ids=metric_class_ids) + metric_class_ids=metric_class_ids, + loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) def _multi_head(heads, loss_weights=None): @@ -412,7 +433,7 @@ class _SingleHead(_Head): # TODO(zakaria): use contrib losses. -def _mean_squared_loss(labels, logits): +def _mean_squared_loss(labels, logits, weights=None): with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name: logits = ops.convert_to_tensor(logits) labels = ops.convert_to_tensor(labels) @@ -423,10 +444,11 @@ def _mean_squared_loss(labels, logits): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, dim=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - return math_ops.square(logits - math_ops.to_float(labels), name=name) + loss = math_ops.square(logits - math_ops.to_float(labels), name=name) + return _compute_weighted_loss(loss, weights) -def _poisson_loss(labels, logits): +def _poisson_loss(labels, logits, weights=None): """Computes poisson loss from logits.""" with ops.name_scope(None, "_poisson_loss", (logits, labels)) as name: logits = ops.convert_to_tensor(logits) @@ -438,8 +460,9 @@ def _poisson_loss(labels, logits): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, dim=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - return nn.log_poisson_loss(labels, logits, - compute_full_loss=True, name=name) + loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, + name=name) + return _compute_weighted_loss(loss, weights) def _logits(logits_input, logits, logits_dimension): @@ -521,8 +544,7 @@ def 
_create_model_fn_ops(features, eval_metric_ops = None if (mode != model_fn.ModeKeys.INFER) and (labels is not None): weight_tensor = _weight_tensor(features, weight_column_name) - loss, weighted_average_loss = _loss( - loss_fn(labels, logits), weight_tensor) + loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) logging_ops.scalar_summary( _summary_key(head_name, mkey.LOSS), weighted_average_loss) @@ -530,7 +552,7 @@ def _create_model_fn_ops(features, if train_op_fn is None: raise ValueError("train_op_fn can not be None in TRAIN mode") train_op = _train_op(loss, labels, train_op_fn, centered_bias, - logits_dimension, loss_fn) + logits_dimension, loss_fn, weight_tensor) eval_metric_ops = metrics_fn( weighted_average_loss, predictions, labels, weight_tensor) return model_fn.ModelFnOps( @@ -641,7 +663,7 @@ class _RegressionHead(_SingleHead): metrics_lib.streaming_mean(eval_loss)} -def _log_loss_with_two_classes(labels, logits): +def _log_loss_with_two_classes(labels, logits, weights=None): with ops.name_scope(None, "log_loss_with_two_classes", (logits, labels)) as name: logits = ops.convert_to_tensor(logits) @@ -650,8 +672,9 @@ def _log_loss_with_two_classes(labels, logits): # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. if len(labels.get_shape()) == 1: labels = array_ops.expand_dims(labels, dim=(1,)) - return nn.sigmoid_cross_entropy_with_logits( - labels=labels, logits=logits, name=name) + loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, + name=name) + return _compute_weighted_loss(loss, weights) def _one_class_to_two_class_logits(logits): @@ -666,7 +689,7 @@ class _BinaryLogisticHead(_SingleHead): weight_column_name=None, enable_centered_bias=False, head_name=None, - loss_fn=_log_loss_with_two_classes, + loss_fn=None, thresholds=None): """Base type for all single heads. 
@@ -695,7 +718,7 @@ class _BinaryLogisticHead(_SingleHead): weight_column_name=weight_column_name, head_name=head_name) self._thresholds = thresholds if thresholds else (.5,) - self._loss_fn = loss_fn + self._loss_fn = loss_fn if loss_fn else _log_loss_with_two_classes self._enable_centered_bias = enable_centered_bias def create_model_fn_ops(self, @@ -803,7 +826,7 @@ class _BinaryLogisticHead(_SingleHead): return metrics -def _softmax_cross_entropy_loss(labels, logits): +def _softmax_cross_entropy_loss(labels, logits, weights=None): with ops.name_scope( None, "softmax_cross_entropy_loss", (logits, labels,)) as name: labels = ops.convert_to_tensor(labels) @@ -815,8 +838,9 @@ def _softmax_cross_entropy_loss(labels, logits): # sparse_softmax_cross_entropy_with_logits requires [batch_size] labels. if len(labels.get_shape()) == 2: labels = array_ops.squeeze(labels, squeeze_dims=(1,)) - return nn.sparse_softmax_cross_entropy_with_logits( + loss = nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name=name) + return _compute_weighted_loss(loss, weights) class _MultiClassHead(_SingleHead): @@ -828,7 +852,7 @@ class _MultiClassHead(_SingleHead): weight_column_name=None, enable_centered_bias=False, head_name=None, - loss_fn=_softmax_cross_entropy_loss, + loss_fn=None, thresholds=None, metric_class_ids=None): """_Head for classification. @@ -865,7 +889,7 @@ class _MultiClassHead(_SingleHead): if (n_classes is None) or (n_classes <= 2): raise ValueError("n_classes must be > 2: %s." 
% n_classes) self._thresholds = thresholds if thresholds else (.5,) - self._loss_fn = loss_fn + self._loss_fn = loss_fn if loss_fn else _softmax_cross_entropy_loss self._enable_centered_bias = enable_centered_bias self._metric_class_ids = tuple([] if metric_class_ids is None else metric_class_ids) @@ -1020,11 +1044,12 @@ class _BinarySvmHead(_SingleHead): def __init__(self, label_name, weight_column_name, enable_centered_bias, head_name, thresholds): - def _loss_fn(labels, logits): + def _loss_fn(labels, logits, weights=None): with ops.name_scope(None, "hinge_loss", (logits, labels)) as name: with ops.control_dependencies((_assert_labels_rank(labels),)): labels = array_ops.reshape(labels, shape=(-1, 1)) - return losses_lib.hinge_loss(logits, labels, scope=name) + loss = losses_lib.hinge_loss(logits=logits, labels=labels, scope=name) + return _compute_weighted_loss(loss, weights) super(_BinarySvmHead, self).__init__( problem_type=constants.ProblemType.LOGISTIC_REGRESSION, @@ -1110,7 +1135,8 @@ class _MultiLabelHead(_SingleHead): enable_centered_bias, head_name, thresholds, - metric_class_ids=None): + metric_class_ids=None, + loss_fn=None): super(_MultiLabelHead, self).__init__( problem_type=constants.ProblemType.CLASSIFICATION, @@ -1120,7 +1146,7 @@ class _MultiLabelHead(_SingleHead): head_name=head_name) self._thresholds = thresholds if thresholds else (.5,) - self._loss_fn = _sigmoid_cross_entropy_loss + self._loss_fn = loss_fn if loss_fn else _sigmoid_cross_entropy_loss self._enable_centered_bias = enable_centered_bias self._metric_class_ids = tuple([] if metric_class_ids is None else metric_class_ids) @@ -1429,15 +1455,6 @@ class _MultiHead(_Head): eval_metric_ops=metrics) -def _weighted_loss(loss, weight): - """Returns cumulative weighted loss as 1d `Tensor`.""" - with ops.name_scope(None, "weighted_loss", (loss, weight)) as name: - return math_ops.multiply( - array_ops.reshape(loss, shape=(-1,)), - array_ops.reshape(weight, shape=(-1,)), - name=name) - - def 
_weight_tensor(features, weight_column_name): """Returns weights as 1d `Tensor`.""" if not weight_column_name: @@ -1447,8 +1464,10 @@ def _weight_tensor(features, weight_column_name): return math_ops.to_float(features[weight_column_name]) -def _loss(loss_unweighted, weight, name="loss"): - """Returns a tuple of (loss, weighted_average_loss). +# TODO(zakaria): This function is needed for backward compatibility and should +# be removed when we migrate to core. +def _compute_weighted_loss(loss_unweighted, weight, name="loss"): + """Returns a tuple of (loss_train, loss_report). loss is used for gradient descent while weighted_average_loss is used for summaries to be backward compatible. @@ -1467,21 +1486,36 @@ def _loss(loss_unweighted, weight, name="loss"): name: Optional name Returns: - A tuple of (loss, weighted_average_loss) + A tuple of losses. First one for training and the second one for reproting. """ with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope: if weight is None: loss = math_ops.reduce_mean(loss_unweighted, name=name_scope) return loss, loss - loss_weighted = _weighted_loss(loss_unweighted, weight) + with ops.name_scope(None, "weighted_loss", + (loss_unweighted, weight)) as name: + weighted_loss = math_ops.multiply( + array_ops.reshape(loss_unweighted, shape=(-1,)), + array_ops.reshape(weight, shape=(-1,)), name=name) # TODO(ptucker): This might be wrong if weights are broadcast to loss shape. # We should use tf.losses here. 
- weighted_average_loss = math_ops.div( - math_ops.reduce_sum(loss_weighted), + weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope) + weighted_loss_normalized = math_ops.div( + math_ops.reduce_sum(weighted_loss), math_ops.to_float(math_ops.reduce_sum(weight)), name="weighted_average_loss") - loss = math_ops.reduce_mean(loss_weighted, name=name_scope) - return loss, weighted_average_loss + + return weighted_loss_mean, weighted_loss_normalized + + +def _wrap_custom_loss_fn(loss_fn): + def _wrapper(labels, logits, weights=None): + if weights is None: + loss = loss_fn(labels, logits) + else: + loss = loss_fn(labels, logits, weights) + return loss, loss + return _wrapper def _check_mode_valid(mode): @@ -1491,6 +1525,26 @@ def _check_mode_valid(mode): raise ValueError("mode=%s unrecognized." % str(mode)) +def _get_arguments(func): + """Returns a spec of given func.""" + if hasattr(func, "__code__"): + # Regular function. + return inspect.getargspec(func) + elif hasattr(func, "__call__"): + # Callable object. + return _get_arguments(func.__call__) + elif hasattr(func, "func"): + # Partial function. + return _get_arguments(func.func) + + +def _verify_loss_fn_args(loss_fn): + args = _get_arguments(loss_fn).args + for arg_name in ["labels", "logits", "weights"]: + if arg_name not in args: + raise ValueError("Argument %s not found in loss_fn." % arg_name) + + def _centered_bias(logits_dimension, head_name=None): """Returns `logits`, optionally with centered bias applied. @@ -1522,7 +1576,8 @@ def _centered_bias(logits_dimension, head_name=None): return centered_bias -def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn): +def _centered_bias_step(centered_bias, logits_dimension, labels, + loss_fn, weights): """Creates and returns training op for centered bias.""" if (logits_dimension is None) or (logits_dimension < 1): raise ValueError("Invalid logits_dimension %s." 
% logits_dimension) @@ -1533,7 +1588,7 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn): (batch_size, logits_dimension)) with ops.name_scope(None, "centered_bias", (labels, logits)): centered_bias_loss = math_ops.reduce_mean( - loss_fn(labels, logits), name="training_loss") + loss_fn(labels, logits, weights), name="training_loss") # Learn central bias by an optimizer. 0.1 is a convervative lr for a # single variable. return training.AdagradOptimizer(0.1).minimize( @@ -1544,16 +1599,12 @@ def _summary_key(head_name, val): return "%s/%s" % (val, head_name) if head_name else val -def _train_op(loss, - labels, - train_op_fn, - centered_bias=None, - logits_dimension=None, - loss_fn=None): +def _train_op(loss, labels, train_op_fn, centered_bias, logits_dimension, + loss_fn, weights): """Returns op for the training step.""" if centered_bias is not None: centered_bias_step = _centered_bias_step(centered_bias, logits_dimension, - labels, loss_fn) + labels, loss_fn, weights) else: centered_bias_step = None with ops.name_scope(None, "train_op", (loss, labels)): @@ -1563,12 +1614,13 @@ def _train_op(loss, return train_op -def _sigmoid_cross_entropy_loss(labels, logits): +def _sigmoid_cross_entropy_loss(labels, logits, weights=None): with ops.name_scope(None, "sigmoid_cross_entropy_loss", (logits, labels)) as name: # sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels. 
- return nn.sigmoid_cross_entropy_with_logits( + loss = nn.sigmoid_cross_entropy_with_logits( labels=math_ops.to_float(labels), logits=logits, name=name) + return _compute_weighted_loss(loss, weights) def _float_weights_or_none(weights): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 52ac8992e5..faa3108caf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -33,6 +33,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test # pylint: enable=g-bad-todo,g-import-not-at-top @@ -507,6 +508,26 @@ class MultiLabelHeadTest(test.TestCase): _assert_metrics(self, .089985214, self._expected_eval_metrics(2.69956), model_fn_ops) + def testMultiLabelWithCustomLoss(self): + n_classes = 3 + head = head_lib._multi_label_head( + n_classes=n_classes, + weight_column_name="label_weight", + metric_class_ids=range(n_classes), + loss_fn=_sigmoid_cross_entropy) + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + features={"label_weight": .1}, + labels=self._labels, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=self._logits) + self._assert_output_alternatives(model_fn_ops) + _assert_no_variables(self) + _assert_summary_tags(self, ["loss"]) + _assert_metrics(self, 0.089985214, + self._expected_eval_metrics(0.089985214), model_fn_ops) + def testMultiLabelWithCenteredBias(self): n_classes = 3 head = head_lib._multi_label_head( @@ -779,13 +800,49 @@ class BinaryClassificationHeadTest(test.TestCase): "auc": 0. / 1, "labels/actual_label_mean": 1. 
/ 1, "labels/prediction_mean": .731059, # softmax - # TODO(ptucker): Is this the correct eval loss, sum not average? + # eval loss is weighted loss divided by sum of weights. "loss": expected_total_loss, "precision/positive_threshold_0.500000_mean": 1. / 1, "recall/positive_threshold_0.500000_mean": 1. / 1, }, model_fn_ops) + def testBinaryClassificationWithCustomLoss(self): + head = head_lib._multi_class_head( + n_classes=2, weight_column_name="label_weight", + loss_fn=_sigmoid_cross_entropy) + with ops.Graph().as_default(), session.Session(): + weights = ((.2,), (0.,)) + model_fn_ops = head.create_model_fn_ops( + features={"label_weight": weights}, + labels=self._labels, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=self._logits) + self._assert_output_alternatives(model_fn_ops) + _assert_no_variables(self) + _assert_summary_tags(self, ["loss"]) + # logloss: z:label, x:logit + # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + # expected_loss is (total_weighted_loss)/1 since htere is 1 nonzero + # weight. + expected_loss = 0.062652342 + _assert_metrics( + self, + expected_loss, + { + "accuracy": 1. / 1, + "accuracy/baseline_label_mean": 1. / 1, + "accuracy/threshold_0.500000_mean": 1. / 1, + "auc": 0. / 1, + "labels/actual_label_mean": 1. / 1, + "labels/prediction_mean": .731059, # softmax + "loss": expected_loss, + "precision/positive_threshold_0.500000_mean": 1. / 1, + "recall/positive_threshold_0.500000_mean": 1. 
/ 1, + }, + model_fn_ops) + def testBinaryClassificationWithCenteredBias(self): head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True) with ops.Graph().as_default(), session.Session(): @@ -1010,6 +1067,30 @@ class MultiClassHeadTest(test.TestCase): _assert_metrics(self, expected_loss * weight, self._expected_eval_metrics(expected_loss), model_fn_ops) + def testMultiClassWithCustomLoss(self): + n_classes = 3 + head = head_lib._multi_class_head( + n_classes=n_classes, + weight_column_name="label_weight", + metric_class_ids=range(n_classes), + loss_fn=losses_lib.sparse_softmax_cross_entropy) + with ops.Graph().as_default(), session.Session(): + weight = .1 + # logloss: z:label, x:logit + # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + model_fn_ops = head.create_model_fn_ops( + features={"label_weight": weight}, + labels=self._labels, + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=self._logits) + self._assert_output_alternatives(model_fn_ops) + _assert_no_variables(self) + _assert_summary_tags(self, ["loss"]) + expected_loss = 1.5514446 * weight + _assert_metrics(self, expected_loss, + self._expected_eval_metrics(expected_loss), model_fn_ops) + def testInvalidNClasses(self): for n_classes in (None, -1, 0, 1): with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"): @@ -1370,5 +1451,9 @@ class MultiHeadTest(test.TestCase): self.assertIn("accuracy/head2", metric_ops.keys()) +def _sigmoid_cross_entropy(labels, logits, weights): + return losses_lib.sigmoid_cross_entropy(labels, logits, weights) + + if __name__ == "__main__": test.main() |