author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-15 12:50:03 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-02-15 12:54:12 -0800
commit     972fa89023f8f27948321c388fa3f1f7857833c3 (patch)
tree       21dd99f81ae7a6d2ed71b455f4cc59b514cc8096 /tensorflow/contrib/metrics
parent     4b297b5434438175b016da05421e7ddd46c0f8ee (diff)
Add auc_with_confidence_intervals
This method computes the AUC and corresponding confidence intervals using an efficient algorithm.

PiperOrigin-RevId: 185884228
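A minimal usage sketch of the new metric, for orientation. The session loop and the `eval_batches` iterable are illustrative assumptions, not part of this change:

    import tensorflow as tf

    labels = tf.placeholder(tf.int64, shape=[None])
    predictions = tf.placeholder(tf.float64, shape=[None])

    # Returns an AucData namedtuple of (auc, lower, upper) tensors plus an
    # update op that concatenates each batch into the accumulated values.
    auc, update_op = tf.contrib.metrics.auc_with_confidence_intervals(
        labels=labels, predictions=predictions, alpha=0.95)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      for batch_labels, batch_scores in eval_batches:  # hypothetical iterable
        sess.run(update_op,
                 feed_dict={labels: batch_labels, predictions: batch_scores})
      point, lower, upper = sess.run([auc.auc, auc.lower, auc.upper])
      print('AUC=%.4f, 95%% CI=(%.4f, %.4f)' % (point, lower, upper))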
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--  tensorflow/contrib/metrics/BUILD                          |   1
-rw-r--r--  tensorflow/contrib/metrics/__init__.py                    |   2
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py       | 291
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  | 199
4 files changed, 493 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 9de664c822..e90c525113 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -43,6 +43,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:weights_broadcast_ops",
+ "//tensorflow/python/ops/distributions",
],
)
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index d3dce46bfb..de02dc8f45 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -16,6 +16,7 @@
See the @{$python/contrib.metrics} guide.
+@@auc_with_confidence_intervals
@@streaming_accuracy
@@streaming_mean
@@streaming_recall
@@ -83,6 +84,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
+from tensorflow.contrib.metrics.python.ops.metric_ops import auc_with_confidence_intervals
from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa
from tensorflow.contrib.metrics.python.ops.metric_ops import count
from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 55946c128b..fc12bfd2b7 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.ops.distributions.normal import Normal
from tensorflow.python.util.deprecation import deprecated
# Epsilon constant used to represent extremely small quantity.
@@ -1196,6 +1197,295 @@ def streaming_dynamic_auc(labels,
return auc, update_op
+def _compute_placement_auc(labels, predictions, weights, alpha,
+ logit_transformation, is_valid):
+ """Computes the AUC and asymptotic normally distributed confidence interval.
+
+  The calculations are achieved using the fact that AUC = P(Y_1 > Y_0) and the
+  concept of placement values for each labeled group, as presented by DeLong
+  and DeLong (1988). The actual algorithm used is a more computationally
+  efficient approach presented by Sun and Xu (2014). This could be slow for
+  large batches, but has the advantage that its results do not degrade
+  depending on the distribution of predictions.
+
+ Args:
+ labels: A `Tensor` of ground truth labels with the same shape as
+ `predictions` with values of 0 or 1 and type `int64`.
+ predictions: A 1-D `Tensor` of predictions whose values are `float64`.
+ weights: `Tensor` whose rank is either 0, or the same rank as `labels`.
+ alpha: Confidence interval level desired.
+ logit_transformation: A boolean value indicating whether the estimate should
+ be logit transformed prior to calculating the confidence interval. Doing
+ so enforces the restriction that the AUC should never be outside the
+ interval [0,1].
+ is_valid: A bool tensor describing whether the input is valid.
+
+ Returns:
+ A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence
+ interval values.
+ """
+ # Disable the invalid-name checker so that we can capitalize the name.
+ # pylint: disable=invalid-name
+ AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper'])
+ # pylint: enable=invalid-name
+
+  # If all the labels are the same or if the number of observations is too
+  # few, the AUC isn't well-defined.
+ size = array_ops.size(predictions, out_type=dtypes.int32)
+
+ # Count the total number of positive and negative labels in the input.
+ total_0 = math_ops.reduce_sum(
+ math_ops.cast(1 - labels, weights.dtype) * weights)
+ total_1 = math_ops.reduce_sum(
+ math_ops.cast(labels, weights.dtype) * weights)
+
+ # Sort the predictions ascending, as well as
+ # (i) the corresponding labels and
+ # (ii) the corresponding weights.
+ ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True)
+ ordered_predictions = array_ops.reverse(
+ ordered_predictions, axis=array_ops.zeros(1, dtypes.int32))
+ indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32))
+ ordered_labels = array_ops.gather(labels, indices)
+ ordered_weights = array_ops.gather(weights, indices)
+
+ # We now compute values required for computing placement values.
+
+  # We generate a list of indices (segmented_indices) in increasing order. An
+  # index is assigned to each unique prediction float value; predictions with
+  # the same value share the same index.
+ _, segmented_indices = array_ops.unique(ordered_predictions)
+
+ # We create 2 tensors of weights. weights_for_true is non-zero for true
+ # labels. weights_for_false is non-zero for false labels.
+ float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32)
+ float_labels_for_false = 1.0 - float_labels_for_true
+ weights_for_true = ordered_weights * float_labels_for_true
+ weights_for_false = ordered_weights * float_labels_for_false
+
+ # For each set of weights with the same segmented indices, we add up the
+ # weight values. Note that for each label, we deliberately rely on weights
+ # for the opposite label.
+ weight_totals_for_true = math_ops.segment_sum(weights_for_false,
+ segmented_indices)
+ weight_totals_for_false = math_ops.segment_sum(weights_for_true,
+ segmented_indices)
+
+ # These cumulative sums of weights importantly exclude the current weight
+ # sums.
+ cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true,
+ exclusive=True)
+ cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false,
+ exclusive=True)
+
+ # Compute placement values using the formula. Values with the same segmented
+ # indices and labels share the same placement values.
+ placements_for_true = (
+ (cum_weight_totals_for_true + weight_totals_for_true / 2.0) /
+ (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON))
+ placements_for_false = (
+ (cum_weight_totals_for_false + weight_totals_for_false / 2.0) /
+ (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON))
+
+ # We expand the tensors of placement values (for each label) so that their
+ # shapes match that of predictions.
+ placements_for_true = array_ops.gather(placements_for_true, segmented_indices)
+ placements_for_false = array_ops.gather(placements_for_false,
+ segmented_indices)
+
+ # Select placement values based on the label for each index.
+ placement_values = (
+ placements_for_true * float_labels_for_true +
+ placements_for_false * float_labels_for_false)
+
+ # Split placement values by labeled groups.
+ placement_values_0 = placement_values * math_ops.cast(
+ 1 - ordered_labels, weights.dtype)
+ weights_0 = ordered_weights * math_ops.cast(
+ 1 - ordered_labels, weights.dtype)
+ placement_values_1 = placement_values * math_ops.cast(
+ ordered_labels, weights.dtype)
+ weights_1 = ordered_weights * math_ops.cast(
+ ordered_labels, weights.dtype)
+
+ # Calculate AUC using placement values
+ auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) /
+ (total_0 + _EPSILON))
+ auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) /
+ (total_1 + _EPSILON))
+ auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0)
+
+ # Calculate variance and standard error using the placement values.
+ var_0 = (
+ math_ops.reduce_sum(
+ weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) /
+ (total_0 - 1. + _EPSILON))
+ var_1 = (
+ math_ops.reduce_sum(
+ weights_1 * math_ops.square(placement_values_1 - auc_1)) /
+ (total_1 - 1. + _EPSILON))
+ auc_std_err = math_ops.sqrt(
+ (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON)))
+
+ # Calculate asymptotic normal confidence intervals
+ std_norm_dist = Normal(loc=0., scale=1.)
+ z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0)
+ if logit_transformation:
+ estimate = math_ops.log(auc / (1. - auc + _EPSILON))
+ std_err = auc_std_err / (auc * (1. - auc + _EPSILON))
+ transformed_auc_lower = estimate + (z_value * std_err)
+ transformed_auc_upper = estimate - (z_value * std_err)
+ def inverse_logit_transformation(x):
+ exp_negative = math_ops.exp(math_ops.negative(x))
+ return 1. / (1. + exp_negative + _EPSILON)
+
+ auc_lower = inverse_logit_transformation(transformed_auc_lower)
+ auc_upper = inverse_logit_transformation(transformed_auc_upper)
+ else:
+ estimate = auc
+ std_err = auc_std_err
+ auc_lower = estimate + (z_value * std_err)
+ auc_upper = estimate - (z_value * std_err)
+
+  ## If the estimate is 1 or 0, no variance is present, so the CI collapses to
+  ## the point estimate.
+  ## N.b. this can be misleading, since the number of observations may simply
+  ## be too low.
+ lower = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(auc, array_ops.ones_like(auc)),
+ math_ops.equal(auc, array_ops.zeros_like(auc))),
+ auc, auc_lower)
+ upper = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(auc, array_ops.ones_like(auc)),
+ math_ops.equal(auc, array_ops.zeros_like(auc))),
+ auc, auc_upper)
+
+  # If all the labels are the same, the AUC isn't well-defined (but raising an
+  # exception seems excessive), so we return 0; otherwise we finish computing.
+ trivial_value = array_ops.constant(0.0)
+
+ return AucData(*control_flow_ops.cond(
+ is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3))
+
+
+def auc_with_confidence_intervals(labels,
+ predictions,
+ weights=None,
+ alpha=0.95,
+ logit_transformation=True,
+ metrics_collections=(),
+ updates_collections=(),
+ name=None):
+ """Computes the AUC and asymptotic normally distributed confidence interval.
+
+ USAGE NOTE: this approach requires storing all of the predictions and labels
+ for a single evaluation in memory, so it may not be usable when the evaluation
+ batch size and/or the number of evaluation steps is very large.
+
+  Computes the area under the ROC curve and its confidence interval using
+  placement values. This approach is resilient to the distribution of
+  predictions because it aggregates across batches: labels and predictions are
+  accumulated, and the final calculation is performed over all of the
+  concatenated values.
+
+ Args:
+    labels: A `Tensor` of ground truth labels with the same shape as
+      `predictions` and with values of 0 or 1 that are castable to `int64`.
+ predictions: A `Tensor` of predictions whose values are castable to
+ `float64`. Will be flattened into a 1-D `Tensor`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`.
+ alpha: Confidence interval level desired.
+ logit_transformation: A boolean value indicating whether the estimate should
+ be logit transformed prior to calculating the confidence interval. Doing
+ so enforces the restriction that the AUC should never be outside the
+ interval [0,1].
+ metrics_collections: An optional iterable of collections that `auc` should
+ be added to.
+ updates_collections: An optional iterable of collections that `update_op`
+ should be added to.
+ name: An optional name for the variable_scope that contains the metric
+ variables.
+
+ Returns:
+ auc: A 1-D `Tensor` containing the current area-under-curve, lower, and
+ upper confidence interval values.
+ update_op: An operation that concatenates the input labels and predictions
+ to the accumulated values.
+
+ Raises:
+ ValueError: If `labels`, `predictions`, and `weights` have mismatched shapes
+ or if `alpha` isn't in the range (0,1).
+ """
+ if not (alpha > 0 and alpha < 1):
+ raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha)
+
+ if weights is None:
+ weights = array_ops.ones_like(predictions)
+
+ with variable_scope.variable_scope(
+ name,
+ default_name='auc_with_confidence_intervals',
+ values=[labels, predictions, weights]):
+
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions,
+ labels=labels,
+ weights=weights)
+
+ total_weight = math_ops.reduce_sum(weights)
+
+ weights = array_ops.reshape(weights, [-1])
+ predictions = array_ops.reshape(
+ math_ops.cast(predictions, dtypes.float64), [-1])
+ labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
+
+ with ops.control_dependencies([
+ check_ops.assert_greater_equal(
+ labels,
+ array_ops.zeros_like(labels, dtypes.int64),
+ message='labels must be 0 or 1, at least one is <0'),
+ check_ops.assert_less_equal(
+ labels,
+ array_ops.ones_like(labels, dtypes.int64),
+ message='labels must be 0 or 1, at least one is >1'),
+ ]):
+ preds_accum, update_preds = streaming_concat(
+ predictions, name='concat_preds')
+ labels_accum, update_labels = streaming_concat(labels,
+ name='concat_labels')
+ weights_accum, update_weights = streaming_concat(
+ weights, name='concat_weights')
+ update_op_for_valid_case = control_flow_ops.group(
+ update_labels, update_preds, update_weights)
+
+ # Only perform updates if this case is valid.
+ all_labels_positive_or_0 = math_ops.logical_and(
+ math_ops.equal(math_ops.reduce_min(labels), 0),
+ math_ops.equal(math_ops.reduce_max(labels), 1))
+ sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0)
+ is_valid = math_ops.logical_and(all_labels_positive_or_0,
+ sums_of_weights_at_least_1)
+
+ update_op = control_flow_ops.cond(
+ sums_of_weights_at_least_1,
+ lambda: update_op_for_valid_case, control_flow_ops.no_op)
+
+ auc = _compute_placement_auc(
+ labels_accum,
+ preds_accum,
+ weights_accum,
+ alpha=alpha,
+ logit_transformation=logit_transformation,
+ is_valid=is_valid)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, auc)
+ return auc, update_op
+
+
def precision_recall_at_equal_thresholds(labels,
predictions,
weights=None,
@@ -3430,6 +3720,7 @@ def cohen_kappa(labels,
__all__ = [
+ 'auc_with_confidence_intervals',
'aggregate_metric_map',
'aggregate_metrics',
'cohen_kappa',
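To make the placement-value computation in _compute_placement_auc concrete, here is a small NumPy sketch of the unweighted case. scipy's norm.ppf stands in for the Normal quantile op; per-example weights, the _EPSILON guards, and the degenerate handling when the AUC is exactly 0 or 1 are omitted.

    import numpy as np
    from scipy import stats

    def placement_auc_ci(labels, scores, alpha=0.95):
      labels = np.asarray(labels, dtype=bool)
      scores = np.asarray(scores, dtype=float)
      pos, neg = scores[labels], scores[~labels]
      n1, n0 = float(len(pos)), float(len(neg))
      # Placement value of a positive: fraction of negatives scored below it,
      # counting ties as one half (and symmetrically for each negative).
      v1 = np.array([(np.sum(neg < p) + 0.5 * np.sum(neg == p)) / n0
                     for p in pos])
      v0 = np.array([(np.sum(pos > q) + 0.5 * np.sum(pos == q)) / n1
                     for q in neg])
      auc = v1.mean()  # equals v0.mean(); both estimate P(Y_1 > Y_0)
      # DeLong-style standard error from the within-class placement variances.
      se = np.sqrt(v1.var(ddof=1) / n1 + v0.var(ddof=1) / n0)
      z = stats.norm.ppf(1.0 - (1.0 - alpha) / 2.0)
      # Logit-transform the estimate so the interval stays inside [0, 1],
      # mirroring the op's logit_transformation=True default.
      est = np.log(auc / (1.0 - auc))
      se_t = se / (auc * (1.0 - auc))
      lower = 1.0 / (1.0 + np.exp(-(est - z * se_t)))
      upper = 1.0 / (1.0 + np.exp(-(est + z * se_t)))
      return auc, lower, upper

On the first test vector below, this reproduces auc ~ 0.66667 with a 95% interval of roughly (0.2783, 0.9121), matching the expected values in metric_ops_test.py up to the _EPSILON guards.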
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b4e365d10f..b387f26c01 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2128,6 +2128,205 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+class AucWithConfidenceIntervalsTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def _testResultsEqual(self, expected_dict, gotten_result):
+ """Tests that 2 results (dicts) represent the same data.
+
+    Args:
+      expected_dict: A dictionary whose keys are the names of the properties
+        of the AUC result namedtuple (auc, lower, upper) and whose values are
+        lists of floats.
+      gotten_result: An AucData namedtuple, as returned by
+        auc_with_confidence_intervals.
+ """
+ gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
+ self.assertItemsEqual(
+ list(expected_dict.keys()), list(gotten_dict.keys()))
+
+ for key, expected_values in expected_dict.items():
+ self.assertAllClose(expected_values, gotten_dict[key])
+
+ def _testCase(self, predictions, labels, expected_result, weights=None):
+ """Performs a test given a certain scenario of labels, predictions, weights.
+
+ Args:
+ predictions: The predictions tensor. Of type float32.
+      labels: The labels tensor. Of type int64.
+ expected_result: The expected result (dict) that maps to tensors.
+ weights: Optional weights tensor.
+ """
+ with self.test_session() as sess:
+ predictions_tensor = constant_op.constant(
+ predictions, dtype=dtypes_lib.float32)
+ labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
+ weights_tensor = None
+ if weights:
+ weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32)
+ gotten_result, update_op = (
+ metric_ops.auc_with_confidence_intervals(
+ labels=labels_tensor,
+ predictions=predictions_tensor,
+ weights=weights_tensor))
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+
+ self._testResultsEqual(expected_result, gotten_result)
+
+ def testAucAllCorrect(self):
+ self._testCase(
+ predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0],
+ labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
+ expected_result={
+ 'auc': 0.66666667,
+ 'lower': 0.27826795,
+ 'upper': 0.91208512,
+ })
+
+ def testAucUnorderedInput(self):
+ self._testCase(
+ predictions=[1.0, 0.6, 0., 0.3, 0.4, 0.2, 0.5, 0.3, 0.6, 0.8],
+ labels=[0, 1, 0, 1, 0, 0, 1, 0, 0, 1],
+ expected_result={
+ 'auc': 0.66666667,
+ 'lower': 0.27826795,
+ 'upper': 0.91208512,
+ })
+
+ def testAucWithWeights(self):
+ self._testCase(
+ predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0],
+ labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
+ weights=[0.5, 0.6, 1.2, 1.5, 2.0, 2.0, 1.5, 1.2, 0.6, 0.5],
+ expected_result={
+ 'auc': 0.65151515,
+ 'lower': 0.28918604,
+ 'upper': 0.89573906,
+ })
+
+ def testAucEqualOne(self):
+ self._testCase(
+ predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0],
+ labels=[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
+ expected_result={
+ 'auc': 1.0,
+ 'lower': 1.0,
+ 'upper': 1.0,
+ })
+
+ def testAucEqualZero(self):
+ self._testCase(
+ predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0],
+ labels=[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+ expected_result={
+ 'auc': 0.0,
+ 'lower': 0.0,
+ 'upper': 0.0,
+ })
+
+ def testNonZeroOnePredictions(self):
+ self._testCase(
+ predictions=[2.5, -2.5, .5, -.5, 1],
+ labels=[1, 0, 1, 0, 0],
+ expected_result={
+ 'auc': 0.83333333,
+ 'lower': 0.15229267,
+ 'upper': 0.99286517,
+ })
+
+ def testAllLabelsOnes(self):
+ self._testCase(
+ predictions=[1., 1., 1., 1., 1.],
+ labels=[1, 1, 1, 1, 1],
+ expected_result={
+ 'auc': 0.,
+ 'lower': 0.,
+ 'upper': 0.,
+ })
+
+ def testAllLabelsZeros(self):
+ self._testCase(
+ predictions=[0., 0., 0., 0., 0.],
+ labels=[0, 0, 0, 0, 0],
+ expected_result={
+ 'auc': 0.,
+ 'lower': 0.,
+ 'upper': 0.,
+ })
+
+ def testWeightSumLessThanOneAll(self):
+ self._testCase(
+ predictions=[1., 1., 0., 1., 0., 0.],
+ labels=[1, 1, 1, 0, 0, 0],
+ weights=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+ expected_result={
+ 'auc': 0.,
+ 'lower': 0.,
+ 'upper': 0.,
+ })
+
+ def testWithMultipleUpdates(self):
+ batch_size = 50
+ num_batches = 100
+ labels = np.array([])
+ predictions = np.array([])
+ tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
+ tf_predictions)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_batches):
+ new_labels = np.random.randint(0, 2, size=batch_size)
+ noise = np.random.normal(0.0, scale=0.2, size=batch_size)
+ new_predictions = 0.4 + 0.2 * new_labels + noise
+ labels = np.concatenate([labels, new_labels])
+ predictions = np.concatenate([predictions, new_predictions])
+ sess.run(tf_labels.assign(new_labels))
+ sess.run(tf_predictions.assign(new_predictions))
+ sess.run(update_op)
+ expected_auc = _np_auc(predictions, labels)
+ self.assertAllClose(expected_auc, auc.auc.eval())
+
+ def testExceptionOnFloatLabels(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
+ labels = constant_op.constant([0.7, 0, 1, 0, 1])
+ _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+      self.assertRaises(TypeError, sess.run, update_op)
+
+ def testExceptionOnGreaterThanOneLabel(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
+ labels = constant_op.constant([2, 1, 0, 1, 0])
+ _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ '.*labels must be 0 or 1, at least one is >1.*'):
+ sess.run(update_op)
+
+ def testExceptionOnNegativeLabel(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
+ labels = constant_op.constant([1, 0, -1, 1, 0])
+ _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
+ sess.run(variables.local_variables_initializer())
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ '.*labels must be 0 or 1, at least one is <0.*'):
+ sess.run(update_op)
+
+
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def setUp(self):