author    Zakaria Haque <zakaria@google.com>          2017-03-06 19:47:17 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-03-06 20:08:35 -0800
commit    3d725349272ca0a5f443ec631374a24474e5a513 (patch)
tree      6bf629aa4de920d88e9fc4c349282c6a95aca351
parent    64809054c9083dbb66d632f47223d6c652d78d13 (diff)
Allows users to provide a custom loss function for multiclass and multilabel problems.
Change: 149376544
-rw-r--r--  tensorflow/contrib/learn/BUILD                                    1
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py        166
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head_test.py    87
3 files changed, 196 insertions(+), 58 deletions(-)
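
In short, this change threads an optional `loss_fn` through the multi-class and multi-label heads: any `tf.losses`-style callable is accepted as long as its parameters are named `labels`, `logits`, and `weights`. A minimal usage sketch, mirroring the new tests below (the head constructors are contrib-internal, and the weight column name is illustrative):

    # Sketch only: intended call pattern for the new loss_fn argument.
    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
    from tensorflow.python.ops.losses import losses as losses_lib

    def my_loss(labels, logits, weights=None):
      # `weights` must appear in the signature; head.py only forwards it
      # when a weight column is actually configured.
      if weights is None:
        return losses_lib.sparse_softmax_cross_entropy(labels, logits)
      return losses_lib.sparse_softmax_cross_entropy(labels, logits, weights)

    head = head_lib._multi_class_head(
        n_classes=3,
        weight_column_name="label_weight",  # optional per-example weights
        loss_fn=my_loss)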
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 959b808d06..83904a77aa 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -602,6 +602,7 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variables",
+ "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index b00c58b6e0..952cdeb5ec 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import abc
+import inspect
+
import six
from tensorflow.contrib import framework as framework_lib
@@ -127,7 +129,8 @@ def _multi_class_head(n_classes,
enable_centered_bias=False,
head_name=None,
thresholds=None,
- metric_class_ids=None):
+ metric_class_ids=None,
+ loss_fn=None):
"""Creates a _Head for multi class single label classification.
The Head uses softmax cross entropy loss.
@@ -149,18 +152,25 @@ def _multi_class_head(n_classes,
metric_class_ids: List of class IDs for which we should report per-class
metrics. Must all be in the range `[0, n_classes)`. Invalid if
`n_classes` is 2.
+ loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
+ parameters and returns a weighted scalar loss. `weights` should be
+ optional. See `tf.losses`.
Returns:
An instance of _MultiClassHead.
Raises:
- ValueError: if `n_classes` is < 2, or `metric_class_ids` is provided when
- `n_classes` is 2.
+ ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when
+ `n_classes` is 2.
+ ValueError: If `loss_fn` does not have the expected signature.
"""
if (n_classes is None) or (n_classes < 2):
raise ValueError("n_classes must be > 1 for classification: %s." %
n_classes)
+ if loss_fn:
+ _verify_loss_fn_args(loss_fn)
+ loss_fn = _wrap_custom_loss_fn(loss_fn) if loss_fn else None
if n_classes == 2:
if metric_class_ids:
raise ValueError("metric_class_ids invalid for n_classes==2.")
@@ -169,7 +179,8 @@ def _multi_class_head(n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
- thresholds=thresholds)
+ thresholds=thresholds,
+ loss_fn=loss_fn)
return _MultiClassHead(
n_classes=n_classes,
@@ -178,7 +189,8 @@ def _multi_class_head(n_classes,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
thresholds=thresholds,
- metric_class_ids=metric_class_ids)
+ metric_class_ids=metric_class_ids,
+ loss_fn=loss_fn)
def _binary_svm_head(
@@ -223,7 +235,8 @@ def _multi_label_head(n_classes,
enable_centered_bias=False,
head_name=None,
thresholds=None,
- metric_class_ids=None):
+ metric_class_ids=None,
+ loss_fn=None):
"""Creates a _Head for multi label classification.
The Head uses sigmoid cross entropy loss.
@@ -244,15 +257,22 @@ def _multi_label_head(n_classes,
thresholds: thresholds for eval metrics, defaults to [.5]
metric_class_ids: List of class IDs for which we should report per-class
metrics. Must all be in the range `[0, n_classes)`.
+ loss_fn: Optional function that takes (`labels`, `logits`, `weights`) as
+ parameters and returns a weighted scalar loss. `weights` should be
+ optional. See `tf.losses`.
Returns:
An instance of _MultiLabelHead.
Raises:
- ValueError: if n_classes is < 2
+ ValueError: If `n_classes` is < 2.
+ ValueError: If `loss_fn` does not have the expected signature.
"""
if n_classes < 2:
raise ValueError("n_classes must be > 1 for classification.")
+ if loss_fn:
+ _verify_loss_fn_args(loss_fn)
+
return _MultiLabelHead(
n_classes=n_classes,
label_name=label_name,
@@ -260,7 +280,8 @@ def _multi_label_head(n_classes,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
thresholds=thresholds,
- metric_class_ids=metric_class_ids)
+ metric_class_ids=metric_class_ids,
+ loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
def _multi_head(heads, loss_weights=None):
@@ -412,7 +433,7 @@ class _SingleHead(_Head):
# TODO(zakaria): use contrib losses.
-def _mean_squared_loss(labels, logits):
+def _mean_squared_loss(labels, logits, weights=None):
with ops.name_scope(None, "mean_squared_loss", (logits, labels)) as name:
logits = ops.convert_to_tensor(logits)
labels = ops.convert_to_tensor(labels)
@@ -423,10 +444,11 @@ def _mean_squared_loss(labels, logits):
if len(logits.get_shape()) == 1:
logits = array_ops.expand_dims(logits, dim=(1,))
logits.get_shape().assert_is_compatible_with(labels.get_shape())
- return math_ops.square(logits - math_ops.to_float(labels), name=name)
+ loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
+ return _compute_weighted_loss(loss, weights)
-def _poisson_loss(labels, logits):
+def _poisson_loss(labels, logits, weights=None):
"""Computes poisson loss from logits."""
with ops.name_scope(None, "_poisson_loss", (logits, labels)) as name:
logits = ops.convert_to_tensor(logits)
@@ -438,8 +460,9 @@ def _poisson_loss(labels, logits):
if len(logits.get_shape()) == 1:
logits = array_ops.expand_dims(logits, dim=(1,))
logits.get_shape().assert_is_compatible_with(labels.get_shape())
- return nn.log_poisson_loss(labels, logits,
- compute_full_loss=True, name=name)
+ loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
+ name=name)
+ return _compute_weighted_loss(loss, weights)
def _logits(logits_input, logits, logits_dimension):
@@ -521,8 +544,7 @@ def _create_model_fn_ops(features,
eval_metric_ops = None
if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
weight_tensor = _weight_tensor(features, weight_column_name)
- loss, weighted_average_loss = _loss(
- loss_fn(labels, logits), weight_tensor)
+ loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor)
logging_ops.scalar_summary(
_summary_key(head_name, mkey.LOSS), weighted_average_loss)
@@ -530,7 +552,7 @@ def _create_model_fn_ops(features,
if train_op_fn is None:
raise ValueError("train_op_fn can not be None in TRAIN mode")
train_op = _train_op(loss, labels, train_op_fn, centered_bias,
- logits_dimension, loss_fn)
+ logits_dimension, loss_fn, weight_tensor)
eval_metric_ops = metrics_fn(
weighted_average_loss, predictions, labels, weight_tensor)
return model_fn.ModelFnOps(
@@ -641,7 +663,7 @@ class _RegressionHead(_SingleHead):
metrics_lib.streaming_mean(eval_loss)}
-def _log_loss_with_two_classes(labels, logits):
+def _log_loss_with_two_classes(labels, logits, weights=None):
with ops.name_scope(None, "log_loss_with_two_classes",
(logits, labels)) as name:
logits = ops.convert_to_tensor(logits)
@@ -650,8 +672,9 @@ def _log_loss_with_two_classes(labels, logits):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
labels = array_ops.expand_dims(labels, dim=(1,))
- return nn.sigmoid_cross_entropy_with_logits(
- labels=labels, logits=logits, name=name)
+ loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
+ name=name)
+ return _compute_weighted_loss(loss, weights)
def _one_class_to_two_class_logits(logits):
@@ -666,7 +689,7 @@ class _BinaryLogisticHead(_SingleHead):
weight_column_name=None,
enable_centered_bias=False,
head_name=None,
- loss_fn=_log_loss_with_two_classes,
+ loss_fn=None,
thresholds=None):
"""Base type for all single heads.
@@ -695,7 +718,7 @@ class _BinaryLogisticHead(_SingleHead):
weight_column_name=weight_column_name,
head_name=head_name)
self._thresholds = thresholds if thresholds else (.5,)
- self._loss_fn = loss_fn
+ self._loss_fn = loss_fn if loss_fn else _log_loss_with_two_classes
self._enable_centered_bias = enable_centered_bias
def create_model_fn_ops(self,
@@ -803,7 +826,7 @@ class _BinaryLogisticHead(_SingleHead):
return metrics
-def _softmax_cross_entropy_loss(labels, logits):
+def _softmax_cross_entropy_loss(labels, logits, weights=None):
with ops.name_scope(
None, "softmax_cross_entropy_loss", (logits, labels,)) as name:
labels = ops.convert_to_tensor(labels)
@@ -815,8 +838,9 @@ def _softmax_cross_entropy_loss(labels, logits):
# sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
if len(labels.get_shape()) == 2:
labels = array_ops.squeeze(labels, squeeze_dims=(1,))
- return nn.sparse_softmax_cross_entropy_with_logits(
+ loss = nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits, name=name)
+ return _compute_weighted_loss(loss, weights)
class _MultiClassHead(_SingleHead):
@@ -828,7 +852,7 @@ class _MultiClassHead(_SingleHead):
weight_column_name=None,
enable_centered_bias=False,
head_name=None,
- loss_fn=_softmax_cross_entropy_loss,
+ loss_fn=None,
thresholds=None,
metric_class_ids=None):
"""_Head for classification.
@@ -865,7 +889,7 @@ class _MultiClassHead(_SingleHead):
if (n_classes is None) or (n_classes <= 2):
raise ValueError("n_classes must be > 2: %s." % n_classes)
self._thresholds = thresholds if thresholds else (.5,)
- self._loss_fn = loss_fn
+ self._loss_fn = loss_fn if loss_fn else _softmax_cross_entropy_loss
self._enable_centered_bias = enable_centered_bias
self._metric_class_ids = tuple([] if metric_class_ids is None else
metric_class_ids)
@@ -1020,11 +1044,12 @@ class _BinarySvmHead(_SingleHead):
def __init__(self, label_name, weight_column_name, enable_centered_bias,
head_name, thresholds):
- def _loss_fn(labels, logits):
+ def _loss_fn(labels, logits, weights=None):
with ops.name_scope(None, "hinge_loss", (logits, labels)) as name:
with ops.control_dependencies((_assert_labels_rank(labels),)):
labels = array_ops.reshape(labels, shape=(-1, 1))
- return losses_lib.hinge_loss(logits, labels, scope=name)
+ loss = losses_lib.hinge_loss(logits=logits, labels=labels, scope=name)
+ return _compute_weighted_loss(loss, weights)
super(_BinarySvmHead, self).__init__(
problem_type=constants.ProblemType.LOGISTIC_REGRESSION,
@@ -1110,7 +1135,8 @@ class _MultiLabelHead(_SingleHead):
enable_centered_bias,
head_name,
thresholds,
- metric_class_ids=None):
+ metric_class_ids=None,
+ loss_fn=None):
super(_MultiLabelHead, self).__init__(
problem_type=constants.ProblemType.CLASSIFICATION,
@@ -1120,7 +1146,7 @@ class _MultiLabelHead(_SingleHead):
head_name=head_name)
self._thresholds = thresholds if thresholds else (.5,)
- self._loss_fn = _sigmoid_cross_entropy_loss
+ self._loss_fn = loss_fn if loss_fn else _sigmoid_cross_entropy_loss
self._enable_centered_bias = enable_centered_bias
self._metric_class_ids = tuple([] if metric_class_ids is None else
metric_class_ids)
@@ -1429,15 +1455,6 @@ class _MultiHead(_Head):
eval_metric_ops=metrics)
-def _weighted_loss(loss, weight):
- """Returns cumulative weighted loss as 1d `Tensor`."""
- with ops.name_scope(None, "weighted_loss", (loss, weight)) as name:
- return math_ops.multiply(
- array_ops.reshape(loss, shape=(-1,)),
- array_ops.reshape(weight, shape=(-1,)),
- name=name)
-
-
def _weight_tensor(features, weight_column_name):
"""Returns weights as 1d `Tensor`."""
if not weight_column_name:
@@ -1447,8 +1464,10 @@ def _weight_tensor(features, weight_column_name):
return math_ops.to_float(features[weight_column_name])
-def _loss(loss_unweighted, weight, name="loss"):
- """Returns a tuple of (loss, weighted_average_loss).
+# TODO(zakaria): This function is needed for backward compatibility and should
+# be removed when we migrate to core.
+def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
+ """Returns a tuple of (loss_train, loss_report).
loss_train is used for gradient descent while loss_report is used for
summaries, to stay backward compatible.
@@ -1467,21 +1486,36 @@ def _loss(loss_unweighted, weight, name="loss"):
name: Optional name
Returns:
- A tuple of (loss, weighted_average_loss)
+ A tuple of losses. The first is for training and the second for reporting.
"""
with ops.name_scope(name, values=(loss_unweighted, weight)) as name_scope:
if weight is None:
loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
return loss, loss
- loss_weighted = _weighted_loss(loss_unweighted, weight)
+ with ops.name_scope(None, "weighted_loss",
+ (loss_unweighted, weight)) as name:
+ weighted_loss = math_ops.multiply(
+ array_ops.reshape(loss_unweighted, shape=(-1,)),
+ array_ops.reshape(weight, shape=(-1,)), name=name)
# TODO(ptucker): This might be wrong if weights are broadcast to loss shape.
# We should use tf.losses here.
- weighted_average_loss = math_ops.div(
- math_ops.reduce_sum(loss_weighted),
+ weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope)
+ weighted_loss_normalized = math_ops.div(
+ math_ops.reduce_sum(weighted_loss),
math_ops.to_float(math_ops.reduce_sum(weight)),
name="weighted_average_loss")
- loss = math_ops.reduce_mean(loss_weighted, name=name_scope)
- return loss, weighted_average_loss
+
+ return weighted_loss_mean, weighted_loss_normalized
+
+
+def _wrap_custom_loss_fn(loss_fn):
+ def _wrapper(labels, logits, weights=None):
+ if weights is None:
+ loss = loss_fn(labels, logits)
+ else:
+ loss = loss_fn(labels, logits, weights)
+ return loss, loss
+ return _wrapper
def _check_mode_valid(mode):
@@ -1491,6 +1525,26 @@ def _check_mode_valid(mode):
raise ValueError("mode=%s unrecognized." % str(mode))
+def _get_arguments(func):
+ """Returns a spec of given func."""
+ if hasattr(func, "__code__"):
+ # Regular function.
+ return inspect.getargspec(func)
+ elif hasattr(func, "__call__"):
+ # Callable object.
+ return _get_arguments(func.__call__)
+ elif hasattr(func, "func"):
+ # Partial function.
+ return _get_arguments(func.func)
+
+
+def _verify_loss_fn_args(loss_fn):
+ args = _get_arguments(loss_fn).args
+ for arg_name in ["labels", "logits", "weights"]:
+ if arg_name not in args:
+ raise ValueError("Argument %s not found in loss_fn." % arg_name)
+
+
def _centered_bias(logits_dimension, head_name=None):
"""Returns `logits`, optionally with centered bias applied.
@@ -1522,7 +1576,8 @@ def _centered_bias(logits_dimension, head_name=None):
return centered_bias
-def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
+def _centered_bias_step(centered_bias, logits_dimension, labels,
+ loss_fn, weights):
"""Creates and returns training op for centered bias."""
if (logits_dimension is None) or (logits_dimension < 1):
raise ValueError("Invalid logits_dimension %s." % logits_dimension)
@@ -1533,7 +1588,7 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
(batch_size, logits_dimension))
with ops.name_scope(None, "centered_bias", (labels, logits)):
centered_bias_loss = math_ops.reduce_mean(
- loss_fn(labels, logits), name="training_loss")
+ loss_fn(labels, logits, weights), name="training_loss")
# Learn central bias by an optimizer. 0.1 is a conservative lr for a
# single variable.
return training.AdagradOptimizer(0.1).minimize(
@@ -1544,16 +1599,12 @@ def _summary_key(head_name, val):
return "%s/%s" % (val, head_name) if head_name else val
-def _train_op(loss,
- labels,
- train_op_fn,
- centered_bias=None,
- logits_dimension=None,
- loss_fn=None):
+def _train_op(loss, labels, train_op_fn, centered_bias, logits_dimension,
+ loss_fn, weights):
"""Returns op for the training step."""
if centered_bias is not None:
centered_bias_step = _centered_bias_step(centered_bias, logits_dimension,
- labels, loss_fn)
+ labels, loss_fn, weights)
else:
centered_bias_step = None
with ops.name_scope(None, "train_op", (loss, labels)):
@@ -1563,12 +1614,13 @@ def _train_op(loss,
return train_op
-def _sigmoid_cross_entropy_loss(labels, logits):
+def _sigmoid_cross_entropy_loss(labels, logits, weights=None):
with ops.name_scope(None, "sigmoid_cross_entropy_loss",
(logits, labels)) as name:
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
- return nn.sigmoid_cross_entropy_with_logits(
+ loss = nn.sigmoid_cross_entropy_with_logits(
labels=math_ops.to_float(labels), logits=logits, name=name)
+ return _compute_weighted_loss(loss, weights)
def _float_weights_or_none(weights):
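
Before the test changes, the arithmetic of the renamed `_compute_weighted_loss` is worth spelling out: the first returned scalar averages the weighted per-example losses over the whole batch and drives gradient descent, while the second normalizes by the sum of the weights and feeds the loss summary. A small sketch of that arithmetic with made-up values:

    import numpy as np

    per_example_loss = np.array([1.0, 2.0, 4.0])
    weights = np.array([0.5, 0.0, 1.0])

    weighted = per_example_loss * weights          # [0.5, 0.0, 4.0]
    loss_train = weighted.mean()                   # 4.5 / 3.0 = 1.5, for gradients
    loss_report = weighted.sum() / weights.sum()   # 4.5 / 1.5 = 3.0, for summaries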
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 52ac8992e5..faa3108caf 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test
# pylint: enable=g-bad-todo,g-import-not-at-top
@@ -507,6 +508,26 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, .089985214,
self._expected_eval_metrics(2.69956), model_fn_ops)
+ def testMultiLabelWithCustomLoss(self):
+ n_classes = 3
+ head = head_lib._multi_label_head(
+ n_classes=n_classes,
+ weight_column_name="label_weight",
+ metric_class_ids=range(n_classes),
+ loss_fn=_sigmoid_cross_entropy)
+ with ops.Graph().as_default(), session.Session():
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": .1},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ _assert_metrics(self, 0.089985214,
+ self._expected_eval_metrics(0.089985214), model_fn_ops)
+
def testMultiLabelWithCenteredBias(self):
n_classes = 3
head = head_lib._multi_label_head(
@@ -779,13 +800,49 @@ class BinaryClassificationHeadTest(test.TestCase):
"auc": 0. / 1,
"labels/actual_label_mean": 1. / 1,
"labels/prediction_mean": .731059, # softmax
- # TODO(ptucker): Is this the correct eval loss, sum not average?
+ # Eval loss is the weighted loss divided by the sum of weights.
"loss": expected_total_loss,
"precision/positive_threshold_0.500000_mean": 1. / 1,
"recall/positive_threshold_0.500000_mean": 1. / 1,
},
model_fn_ops)
+ def testBinaryClassificationWithCustomLoss(self):
+ head = head_lib._multi_class_head(
+ n_classes=2, weight_column_name="label_weight",
+ loss_fn=_sigmoid_cross_entropy)
+ with ops.Graph().as_default(), session.Session():
+ weights = ((.2,), (0.,))
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weights},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero
+ # weight.
+ expected_loss = 0.062652342
+ _assert_metrics(
+ self,
+ expected_loss,
+ {
+ "accuracy": 1. / 1,
+ "accuracy/baseline_label_mean": 1. / 1,
+ "accuracy/threshold_0.500000_mean": 1. / 1,
+ "auc": 0. / 1,
+ "labels/actual_label_mean": 1. / 1,
+ "labels/prediction_mean": .731059, # softmax
+ "loss": expected_loss,
+ "precision/positive_threshold_0.500000_mean": 1. / 1,
+ "recall/positive_threshold_0.500000_mean": 1. / 1,
+ },
+ model_fn_ops)
+
def testBinaryClassificationWithCenteredBias(self):
head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True)
with ops.Graph().as_default(), session.Session():
@@ -1010,6 +1067,30 @@ class MultiClassHeadTest(test.TestCase):
_assert_metrics(self, expected_loss * weight,
self._expected_eval_metrics(expected_loss), model_fn_ops)
+ def testMultiClassWithCustomLoss(self):
+ n_classes = 3
+ head = head_lib._multi_class_head(
+ n_classes=n_classes,
+ weight_column_name="label_weight",
+ metric_class_ids=range(n_classes),
+ loss_fn=losses_lib.sparse_softmax_cross_entropy)
+ with ops.Graph().as_default(), session.Session():
+ weight = .1
+ # Expected loss is sparse softmax cross entropy, scaled by the weight.
+ model_fn_ops = head.create_model_fn_ops(
+ features={"label_weight": weight},
+ labels=self._labels,
+ mode=model_fn.ModeKeys.TRAIN,
+ train_op_fn=head_lib.no_op_train_fn,
+ logits=self._logits)
+ self._assert_output_alternatives(model_fn_ops)
+ _assert_no_variables(self)
+ _assert_summary_tags(self, ["loss"])
+ expected_loss = 1.5514446 * weight
+ _assert_metrics(self, expected_loss,
+ self._expected_eval_metrics(expected_loss), model_fn_ops)
+
def testInvalidNClasses(self):
for n_classes in (None, -1, 0, 1):
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
@@ -1370,5 +1451,9 @@ class MultiHeadTest(test.TestCase):
self.assertIn("accuracy/head2", metric_ops.keys())
+def _sigmoid_cross_entropy(labels, logits, weights):
+ return losses_lib.sigmoid_cross_entropy(labels, logits, weights)
+
+
if __name__ == "__main__":
test.main()
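
A final note on the `_sigmoid_cross_entropy` helper above: `_verify_loss_fn_args` matches parameter names literally, and the first parameter of `losses_lib.sigmoid_cross_entropy` is named `multi_class_labels`, not `labels`, so passing the library function straight to a head would fail the check. The wrapper exists purely to rename the arguments; hypothetically:

    # Hypothetical: fails the name check, since inspect finds
    # `multi_class_labels` rather than `labels` in the signature.
    head_lib._multi_label_head(
        n_classes=3, loss_fn=losses_lib.sigmoid_cross_entropy)
    # ValueError: Argument labels not found in loss_fn.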