author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-04 12:49:13 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-04 14:06:45 -0700
commit  7e0b20510f25c6fb12ee8c055e32fb575f588abb (patch)
tree    0b230227845f3b35d3ff75d7e335ef595c3743a5
parent  fd69bb292af7f15cd364e36ead7f596a3c484b2c (diff)
Add sparse_recall_at_top_k which takes top-k class indices instead of class logits.
Change: 155121560
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py       |  84
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  | 275
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                     |  69
3 files changed, 427 insertions(+), 1 deletion(-)
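
In short: the existing `streaming_sparse_recall_at_k` consumes class logits and
runs `top_k` internally, while the new `sparse_recall_at_top_k` consumes
precomputed top-k class indices. A minimal sketch of the two call styles, in
TF 1.x graph mode, with values borrowed from the tests in this patch (the
sketch itself is not part of the patch):

    import tensorflow as tf

    labels = tf.constant([[3], [2]], dtype=tf.int64)

    # Existing op: logits in, top-k computed internally.
    logits = tf.constant([[0.1, 0.3, 0.2, 0.4],
                          [0.1, 0.2, 0.3, 0.4]])
    recall_a, update_a = tf.contrib.metrics.streaming_sparse_recall_at_k(
        predictions=logits, labels=labels, k=1)

    # New op: callers that already have top-k indices (e.g. from tf.nn.top_k
    # or an approximate nearest-neighbor lookup) skip the logits entirely.
    top_k = tf.constant([[3], [3]], dtype=tf.int32)
    recall_b, update_b = tf.contrib.metrics.sparse_recall_at_top_k(
        labels=labels, top_k_predictions=top_k)

Both pairs evaluate to the same recall once their update ops have run.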
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index d57203c042..727cdd9597 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
name=name_scope)
+def sparse_recall_at_top_k(labels,
+ top_k_predictions,
+ class_id=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes recall@k of top-k predictions with respect to sparse labels.
+
+  If `class_id` is specified, we calculate recall by considering only the
+  entries in the batch for which `class_id` is in the label, and computing
+  the fraction of them for which `class_id` is in `top_k_predictions`.
+  If `class_id` is not specified, we calculate recall as how often, on
+  average, a class among the labels of a batch entry is in
+  `top_k_predictions`.
+
+  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
+  and `false_negative_at_<k>`, that are used to compute the recall@k
+  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
+  idempotent operation that simply divides `true_positive_at_<k>` by the sum
+  of `true_positive_at_<k>` and `false_negative_at_<k>`.
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the
+  `recall_at_<k>`. Set operations applied to `top_k_predictions` and `labels`
+  calculate the true positives and false negatives weighted by `weights`.
+  Then `update_op` increments `true_positive_at_<k>` and
+  `false_negative_at_<k>` using these values.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+    labels: `int64` `Tensor` or `SparseTensor` with shape
+      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+      target classes for the associated prediction. Commonly, N=1 and `labels`
+      has shape [batch_size, num_labels]. [D1, ... DN] must match
+      `top_k_predictions`. Values should be in range [0, num_classes), where
+      num_classes is the number of target classes. Values outside this range
+      always count towards `false_negative_at_<k>`.
+    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
+      N >= 1. Commonly, N=1 and `top_k_predictions` has shape [batch_size, k].
+      The final dimension contains the top-k predicted class indices.
+      [D1, ... DN] must match `labels`.
+    class_id: Integer class ID for which we want binary metrics. This should
+      be in range [0, num_classes), where num_classes is the number of target
+      classes. If `class_id` is outside this range, the method returns NAN.
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+ dimensions must be either `1`, or the same as the corresponding `labels`
+ dimension).
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+ by the sum of `true_positives` and `false_negatives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_negatives` variables appropriately, and whose value matches
+ `recall`.
+
+ Raises:
+    ValueError: If `weights` is not `None` and its shape doesn't match
+      `labels`, or if either `metrics_collections` or `updates_collections`
+      are not a list or tuple.
+ """
+ default_name = _at_k_name('recall', class_id=class_id)
+ with ops.name_scope(name, default_name, (top_k_predictions, labels,
+ weights)) as name_scope:
+ return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
+ labels=labels,
+ predictions_idx=top_k_predictions,
+ class_id=class_id,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=name_scope)
+
+
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
@@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
+ 'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
'streaming_false_negatives',
@@ -2310,7 +2392,9 @@ __all__ = [
'streaming_root_mean_squared_error',
'streaming_sensitivity_at_specificity',
'streaming_sparse_average_precision_at_k',
+ 'streaming_sparse_average_precision_at_top_k',
'streaming_sparse_precision_at_k',
+ 'streaming_sparse_precision_at_top_k',
'streaming_sparse_recall_at_k',
'streaming_specificity_at_sensitivity',
'streaming_true_negatives',
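
As with the other streaming metrics exported here, the returned `recall`
tensor reads the accumulated counters and `update_op` folds a batch into them,
so the local variables must be initialized before either is evaluated. A
hedged usage sketch (TF 1.x; the class_id=3 case from the tests below):

    import tensorflow as tf
    from tensorflow.contrib import metrics as contrib_metrics

    labels = tf.constant([[3], [2]], dtype=tf.int64)
    top_k = tf.constant([[3], [3]], dtype=tf.int32)
    recall, update_op = contrib_metrics.sparse_recall_at_top_k(
        labels=labels, top_k_predictions=top_k, class_id=3)

    with tf.Session() as sess:
      # Streaming counters live in local (not global) variables.
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)      # one label for class 3, and it was predicted
      print(sess.run(recall))  # 1.0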
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b960e1310e..f42e974e23 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
+ def _test_sparse_recall_at_top_k(self,
+ labels,
+ top_k_predictions,
+ expected,
+ class_id=None,
+ weights=None):
+ with ops.Graph().as_default() as g, self.test_session(g):
+ if weights is not None:
+ weights = constant_op.constant(weights, dtypes_lib.float32)
+ metric, update = metric_ops.sparse_recall_at_top_k(
+ labels=labels,
+ top_k_predictions=constant_op.constant(top_k_predictions,
+ dtypes_lib.int32),
+ class_id=class_id,
+ weights=weights)
+
+ # Fails without initialized vars.
+ self.assertRaises(errors_impl.OpError, metric.eval)
+ self.assertRaises(errors_impl.OpError, update.eval)
+ variables.variables_initializer(variables.local_variables()).run()
+
+ # Run per-step op and assert expected values.
+ if math.isnan(expected):
+ self.assertTrue(math.isnan(update.eval()))
+ self.assertTrue(math.isnan(metric.eval()))
+ else:
+ self.assertEqual(expected, update.eval())
+ self.assertEqual(expected, metric.eval())
+
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (-1, 0, 1, 4):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 0 predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0, class_id=2)
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, class_id=3)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, class_id=3)
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2)
def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2,
class_id=3,
weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=3.0 / 3,
class_id=3,
weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.3 / 0.3,
class_id=3,
weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.6 / 0.6,
class_id=3,
weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
+
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_three_labels_at_k5_no_predictions(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 8: 1 label, no predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=8)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, class_id=8)
def test_three_labels_at_k5(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 6 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=3.0 / 6)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 6)
def test_three_labels_at_k5_some_out_of_range(self):
"""Tests that labels outside the [0, n_classes) count in denominator."""
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sp_labels = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
[1, 3]],
@@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=2.0 / 2,
class_id=2)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=2)
# Class 5: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
@@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=1.0 / 1,
class_id=5)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=0.0 / 1,
class_id=7)
+ self._test_sparse_recall_at_top_k(
+ sp_labels,
+ top_k_predictions,
+ expected=0.0 / 1,
+ class_id=7)
# All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
+ self._test_sparse_recall_at_top_k(
+ sp_labels, top_k_predictions, expected=3.0 / 8)
def test_3d_nan(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
@@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_3d_no_predictions(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (1, 8):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0, class_id=class_id)
def test_3d(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=4.0 / 4, class_id=2)
# Class 5: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 2, class_id=5)
# Class 7: 2 labels, 1 incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, class_id=7)
# All classes: 12 labels, 7 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=7.0 / 12)
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=7.0 / 12)
def test_3d_ignore_all(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=class_id,
+ weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0, 0], [0, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=class_id,
+ weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
def test_3d_ignore_some(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[1], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2.0,
+ class_id=2,
+ weights=[[1], [0]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[0], [1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2.0,
+ class_id=2,
+ weights=[[0], [1]])
# Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
@@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1.0,
class_id=7,
weights=[[0], [1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1.0,
+ class_id=7,
+ weights=[[0], [1]])
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1.0,
class_id=7,
weights=[[1], [0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.0 / 1.0,
+ class_id=7,
+ weights=[[1], [0]])
# Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k(
@@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 2.0,
class_id=7,
weights=[[1, 0], [1, 0]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 2.0,
+ class_id=7,
+ weights=[[1, 0], [1, 0]])
# Class 7: No labels.
self._test_streaming_sparse_recall_at_k(
@@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=7,
weights=[[0, 1], [0, 1]])
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=7,
+ weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self):
predictions = [[0.1, 0.3, 0.2, 0.4],
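
The expected values asserted throughout this file come from simple set
arithmetic on labels and top-k indices: recall is the number of labeled
classes that appear in the corresponding top-k row, divided by the total
number of labels. A small pure-Python sketch (a hypothetical helper, not part
of the change) reproducing the unweighted three-labels-at-k5 case:

    def recall_at_top_k(label_sets, top_k_rows):
      # True positives: labels that show up in their row's top-k indices.
      tp = sum(len(labels & set(row))
               for labels, row in zip(label_sets, top_k_rows))
      total = sum(len(labels) for labels in label_sets)
      return float(tp) / total

    label_sets = [{2, 7, 8}, {1, 2, 5}]             # dense_labels above
    top_k_rows = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    print(recall_at_top_k(label_sets, top_k_rows))  # 0.5, i.e. 3.0 / 6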
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 4dc8e702ca..28ed3af9d7 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -1924,7 +1924,74 @@ def recall_at_k(labels,
labels = _maybe_expand_labels(labels, predictions)
_, top_k_idx = nn.top_k(predictions, k)
- top_k_idx = math_ops.to_int64(top_k_idx)
+ return _sparse_recall_at_top_k(
+ labels=labels,
+ predictions_idx=top_k_idx,
+ k=k,
+ class_id=class_id,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=scope)
+
+
+def _sparse_recall_at_top_k(labels,
+ predictions_idx,
+ k=None,
+ class_id=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes recall@k of top-k predictions with respect to sparse labels.
+
+ Differs from `recall_at_k` in that predictions must be in the form of top `k`
+ class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
+ for more details.
+
+ Args:
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
+ the associated prediction. Commonly, N=1 and `labels` has shape
+      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
+      Values should be in range [0, num_classes), where num_classes is the
+      number of target classes. Values outside this range always count
+      towards `false_negative_at_<k>`.
+    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+      Commonly, N=1 and `predictions_idx` has shape [batch_size, k]. The final
+      dimension contains the top `k` predicted class indices. [D1, ... DN]
+      must match `labels`.
+ k: Integer, k for @k metric.
+    class_id: Integer class ID for which we want binary metrics. This should
+      be in range [0, num_classes), where num_classes is the number of target
+      classes. If `class_id` is outside this range, the method returns NAN.
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+ dimensions must be either `1`, or the same as the corresponding `labels`
+ dimension).
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+ by the sum of `true_positives` and `false_negatives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_negatives` variables appropriately, and whose value matches
+ `recall`.
+
+ Raises:
+    ValueError: If `weights` is not `None` and its shape doesn't match
+      `labels`, or if either `metrics_collections` or `updates_collections`
+      are not a list or tuple.
+ """
+ with ops.name_scope(name,
+ _at_k_name('recall', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
+ top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
weights=weights)
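
After this refactor, `recall_at_k` is a thin wrapper: it runs `nn.top_k` on
the logits and delegates to `_sparse_recall_at_top_k`, which is also what the
new contrib op fronts. A hedged equivalence sketch (TF 1.x; values taken from
the k=5 tests in this patch):

    import tensorflow as tf

    labels = tf.constant([[2, 7, 8], [1, 2, 5]], dtype=tf.int64)
    logits = tf.constant(
        [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
         [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]])

    # Public path: tf.metrics.recall_at_k sorts the logits itself.
    recall_a, update_a = tf.metrics.recall_at_k(
        labels=labels, predictions=logits, k=5)

    # Manual path: the same computation from precomputed indices, via the
    # contrib op that fronts _sparse_recall_at_top_k.
    _, top_k_idx = tf.nn.top_k(logits, k=5)
    recall_b, update_b = tf.contrib.metrics.sparse_recall_at_top_k(
        labels=labels, top_k_predictions=top_k_idx)

Both paths yield 3.0 / 6 on this batch once their update ops have run.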