author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-04 12:49:13 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-04 14:06:45 -0700
commit     7e0b20510f25c6fb12ee8c055e32fb575f588abb (patch)
tree       0b230227845f3b35d3ff75d7e335ef595c3743a5
parent     fd69bb292af7f15cd364e36ead7f596a3c484b2c (diff)
Add sparse_recall_at_top_k which takes top-k class indices instead of class logits.
Change: 155121560
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py        84
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  275
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                       69
3 files changed, 427 insertions(+), 1 deletion(-)
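Usage sketch (editorial illustration, not part of this commit): the new metric consumes top-k class indices directly instead of logits. This assumes the TF 1.x graph-mode session API and the contrib import path used by the tests below; the data values mirror test_one_label_at_k1.

import tensorflow as tf
from tensorflow.contrib.metrics.python.ops import metric_ops

# Each row is one example; labels hold target class indices.
labels = tf.constant([[3], [2]], dtype=tf.int64)
# Top-k predicted class indices (k=1 here), not logits.
top_k_predictions = tf.constant([[3], [3]], dtype=tf.int32)

recall, update_op = metric_ops.sparse_recall_at_top_k(
    labels=labels, top_k_predictions=top_k_predictions)

with tf.Session() as sess:
  # The metric's counters are local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)      # accumulate true positives / false negatives
  print(sess.run(recall))  # 0.5: class 3 is recalled, class 2 is missed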
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index d57203c042..727cdd9597 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
       name=name_scope)
 
 
+def sparse_recall_at_top_k(labels,
+                           top_k_predictions,
+                           class_id=None,
+                           weights=None,
+                           metrics_collections=None,
+                           updates_collections=None,
+                           name=None):
+  """Computes recall@k of top-k predictions with respect to sparse labels.
+
+  If `class_id` is specified, we calculate recall by considering only the
+  entries in the batch for which `class_id` is in the label, and computing
+  the fraction of them for which `class_id` is in the top-k `predictions`.
+  If `class_id` is not specified, we'll calculate recall as how often on
+  average a class among the labels of a batch entry is in the top-k
+  `predictions`.
+
+  `sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
+  and `false_negative_at_<k>`, that are used to compute the recall_at_k
+  frequency. This frequency is ultimately returned as `recall_at_<k>`: an
+  idempotent operation that simply divides `true_positive_at_<k>` by total
+  (`true_positive_at_<k>` + `false_negative_at_<k>`).
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the
+  `recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
+  true positives and false negatives weighted by `weights`. Then `update_op`
+  increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
+  values.
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+  Args:
+    labels: `int64` `Tensor` or `SparseTensor` with shape
+      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+      target classes for the associated prediction. Commonly, N=1 and `labels`
+      has shape [batch_size, num_labels]. [D1, ... DN] must match
+      `top_k_predictions`. Values should be in range [0, num_classes), where
+      num_classes is the last dimension of `predictions`. Values outside this
+      range always count towards `false_negative_at_<k>`.
+    top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
+      N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
+      The final dimension contains the indices of top-k labels. [D1, ... DN]
+      must match `labels`.
+    class_id: Integer class ID for which we want binary metrics. This should be
+      in range [0, num_classes), where num_classes is the last dimension of
+      `predictions`. If class_id is outside this range, the method returns NAN.
+    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+      dimensions must be either `1`, or the same as the corresponding `labels`
+      dimension).
+    metrics_collections: An optional list of collections that values should
+      be added to.
+    updates_collections: An optional list of collections that updates should
+      be added to.
+    name: Name of new update operation, and namespace for other dependent ops.
+
+  Returns:
+    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+      by the sum of `true_positives` and `false_negatives`.
+    update_op: `Operation` that increments `true_positives` and
+      `false_negatives` variables appropriately, and whose value matches
+      `recall`.
+
+  Raises:
+    ValueError: If `weights` is not `None` and its shape doesn't match
+      `predictions`, or if either `metrics_collections` or `updates_collections`
+      are not a list or tuple.
+  """
+  default_name = _at_k_name('recall', class_id=class_id)
+  with ops.name_scope(name, default_name, (top_k_predictions, labels,
+                                           weights)) as name_scope:
+    return metrics_impl._sparse_recall_at_top_k(  # pylint: disable=protected-access
+        labels=labels,
+        predictions_idx=top_k_predictions,
+        class_id=class_id,
+        weights=weights,
+        metrics_collections=metrics_collections,
+        updates_collections=updates_collections,
+        name=name_scope)
+
+
 def streaming_sparse_average_precision_at_k(predictions,
                                             labels,
                                             k,
@@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
 __all__ = [
     'aggregate_metric_map',
     'aggregate_metrics',
+    'sparse_recall_at_top_k',
     'streaming_accuracy',
     'streaming_auc',
     'streaming_false_negatives',
@@ -2310,7 +2392,9 @@ __all__ = [
     'streaming_root_mean_squared_error',
     'streaming_sensitivity_at_specificity',
     'streaming_sparse_average_precision_at_k',
+    'streaming_sparse_average_precision_at_top_k',
     'streaming_sparse_precision_at_k',
+    'streaming_sparse_precision_at_top_k',
     'streaming_sparse_recall_at_k',
     'streaming_specificity_at_sensitivity',
     'streaming_true_negatives',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index b960e1310e..f42e974e23 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
       self.assertEqual(expected, update.eval())
       self.assertEqual(expected, metric.eval())
 
+  def _test_sparse_recall_at_top_k(self,
+                                   labels,
+                                   top_k_predictions,
+                                   expected,
+                                   class_id=None,
+                                   weights=None):
+    with ops.Graph().as_default() as g, self.test_session(g):
+      if weights is not None:
+        weights = constant_op.constant(weights, dtypes_lib.float32)
+      metric, update = metric_ops.sparse_recall_at_top_k(
+          labels=labels,
+          top_k_predictions=constant_op.constant(top_k_predictions,
+                                                 dtypes_lib.int32),
+          class_id=class_id,
+          weights=weights)
+
+      # Fails without initialized vars.
+      self.assertRaises(errors_impl.OpError, metric.eval)
+      self.assertRaises(errors_impl.OpError, update.eval)
+      variables.variables_initializer(variables.local_variables()).run()
+
+      # Run per-step op and assert expected values.
+      if math.isnan(expected):
+        self.assertTrue(math.isnan(update.eval()))
+        self.assertTrue(math.isnan(metric.eval()))
+      else:
+        self.assertEqual(expected, update.eval())
+        self.assertEqual(expected, metric.eval())
+
   def test_one_label_at_k1_nan(self):
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+    top_k_predictions = [[3], [3]]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 0, 1], [0, 0, 1, 0]])
     dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
       for class_id in (-1, 0, 1, 4):
         self._test_streaming_sparse_recall_at_k(
             predictions, labels, k=1, expected=NAN, class_id=class_id)
+        self._test_sparse_recall_at_top_k(
+            labels, top_k_predictions, expected=NAN, class_id=class_id)
 
   def test_one_label_at_k1_no_predictions(self):
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+    top_k_predictions = [[3], [3]]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 0, 1], [0, 0, 1, 0]])
     dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 2: 0 predictions.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=0.0, class_id=2)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.0, class_id=2)
 
   def test_one_label_at_k1(self):
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+    top_k_predictions = [[3], [3]]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 0, 1], [0, 0, 1, 0]])
     dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 3: 1 label, 2 predictions, 1 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 1, class_id=3)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 1, class_id=3)
 
       # All classes: 2 labels, 2 predictions, 1 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 2)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 2)
 
   def test_one_label_at_k1_weighted(self):
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+    top_k_predictions = [[3], [3]]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 0, 1], [0, 0, 1, 0]])
     dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 3: 1 label, 2 predictions, 1 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=1.0 / 1,
           class_id=3,
           weights=(1.0,))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=1.0 / 1,
+          class_id=3,
+          weights=(1.0,))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=1.0 / 1,
           class_id=3,
           weights=(2.0,))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=1.0 / 1,
+          class_id=3,
+          weights=(2.0,))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=NAN,
           class_id=3,
           weights=(0.0, 0.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=NAN,
+          class_id=3,
+          weights=(0.0, 0.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
          expected=NAN,
           class_id=3,
           weights=(0.0, 1.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=NAN,
+          class_id=3,
+          weights=(0.0, 1.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=1.0 / 1,
           class_id=3,
           weights=(1.0, 0.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=1.0 / 1,
+          class_id=3,
+          weights=(1.0, 0.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=1.0 / 1,
           class_id=3,
           weights=(1.0, 1.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=1.0 / 1,
+          class_id=3,
+          weights=(1.0, 1.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=2.0 / 2,
           class_id=3,
           weights=(2.0, 3.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=2.0 / 2,
+          class_id=3,
+          weights=(2.0, 3.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=3.0 / 3,
           class_id=3,
           weights=(3.0, 2.0))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=3.0 / 3,
+          class_id=3,
+          weights=(3.0, 2.0))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=0.3 / 0.3,
           class_id=3,
           weights=(0.3, 0.6))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=0.3 / 0.3,
+          class_id=3,
+          weights=(0.3, 0.6))
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=0.6 / 0.6,
           class_id=3,
           weights=(0.6, 0.3))
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=0.6 / 0.6,
+          class_id=3,
+          weights=(0.6, 0.3))
 
       # All classes: 2 labels, 2 predictions, 1 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=NAN, weights=(0.0,))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=NAN, weights=(0.0,))
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
+
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
 
   def test_three_labels_at_k5_nan(self):
     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+    top_k_predictions = [
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
       for class_id in (0, 3, 4, 6, 9, 10):
         self._test_streaming_sparse_recall_at_k(
             predictions, labels, k=5, expected=NAN, class_id=class_id)
+        self._test_sparse_recall_at_top_k(
+            labels, top_k_predictions, expected=NAN, class_id=class_id)
 
   def test_three_labels_at_k5_no_predictions(self):
     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+    top_k_predictions = [
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 8: 1 label, no predictions.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=0.0 / 1, class_id=8)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.0 / 1, class_id=8)
 
   def test_three_labels_at_k5(self):
     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+    top_k_predictions = [
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ]
     sparse_labels = _binary_2d_label_to_sparse_value(
         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 2: 2 labels, both correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=2.0 / 2, class_id=2)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=2.0 / 2, class_id=2)
 
       # Class 5: 1 label, correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=1.0 / 1, class_id=5)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 1, class_id=5)
 
       # Class 7: 1 label, incorrect.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=0.0 / 1, class_id=7)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=0.0 / 1, class_id=7)
 
       # All classes: 6 labels, 3 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=3.0 / 6)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=3.0 / 6)
 
   def test_three_labels_at_k5_some_out_of_range(self):
     """Tests that labels outside the [0, n_classes) count in denominator."""
     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
+    top_k_predictions = [
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ]
     sp_labels = sparse_tensor.SparseTensorValue(
         indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
                  [1, 3]],
@@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
         k=5,
         expected=2.0 / 2,
         class_id=2)
+    self._test_sparse_recall_at_top_k(
+        sp_labels,
+        top_k_predictions,
+        expected=2.0 / 2,
+        class_id=2)
 
     # Class 5: 1 label, correct.
     self._test_streaming_sparse_recall_at_k(
@@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
         k=5,
         expected=1.0 / 1,
         class_id=5)
+    self._test_sparse_recall_at_top_k(
+        sp_labels,
+        top_k_predictions,
+        expected=1.0 / 1,
+        class_id=5)
 
     # Class 7: 1 label, incorrect.
     self._test_streaming_sparse_recall_at_k(
@@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
        k=5,
        expected=0.0 / 1,
        class_id=7)
+    self._test_sparse_recall_at_top_k(
+        sp_labels,
+        top_k_predictions,
+        expected=0.0 / 1,
+        class_id=7)
 
     # All classes: 8 labels, 3 correct.
     self._test_streaming_sparse_recall_at_k(
         predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
+    self._test_sparse_recall_at_top_k(
+        sp_labels, top_k_predictions, expected=3.0 / 8)
 
   def test_3d_nan(self):
     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+    top_k_predictions = [[
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ], [
+        [5, 7, 2, 9, 6],
+        [9, 4, 6, 2, 0],
+    ]]
     sparse_labels = _binary_3d_label_to_sparse_value(
         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
          [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
@@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
       for class_id in (0, 3, 4, 6, 9, 10):
         self._test_streaming_sparse_recall_at_k(
             predictions, labels, k=5, expected=NAN, class_id=class_id)
+        self._test_sparse_recall_at_top_k(
+            labels, top_k_predictions, expected=NAN, class_id=class_id)
 
   def test_3d_no_predictions(self):
     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+    top_k_predictions = [[
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ], [
+        [5, 7, 2, 9, 6],
+        [9, 4, 6, 2, 0],
+    ]]
     sparse_labels = _binary_3d_label_to_sparse_value(
         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
       for class_id in (1, 8):
         self._test_streaming_sparse_recall_at_k(
             predictions, labels, k=5, expected=0.0, class_id=class_id)
+        self._test_sparse_recall_at_top_k(
+            labels, top_k_predictions, expected=0.0, class_id=class_id)
 
   def test_3d(self):
     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+    top_k_predictions = [[
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ], [
+        [5, 7, 2, 9, 6],
+        [9, 4, 6, 2, 0],
+    ]]
     labels = _binary_3d_label_to_sparse_value(
         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
       # Class 2: 4 labels, all correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=4.0 / 4, class_id=2)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=4.0 / 4, class_id=2)
 
       # Class 5: 2 labels, both correct.
       self._test_streaming_sparse_recall_at_k(
          predictions, labels, k=5, expected=2.0 / 2, class_id=5)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=2.0 / 2, class_id=5)
 
       # Class 7: 2 labels, 1 incorrect.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=1.0 / 2, class_id=7)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=1.0 / 2, class_id=7)
 
       # All classes: 12 labels, 7 correct.
       self._test_streaming_sparse_recall_at_k(
           predictions, labels, k=5, expected=7.0 / 12)
+      self._test_sparse_recall_at_top_k(
+          labels, top_k_predictions, expected=7.0 / 12)
 
   def test_3d_ignore_all(self):
     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+    top_k_predictions = [[
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ], [
+        [5, 7, 2, 9, 6],
+        [9, 4, 6, 2, 0],
+    ]]
     labels = _binary_3d_label_to_sparse_value(
         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=NAN,
          class_id=class_id,
           weights=[[0], [0]])
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=NAN,
+          class_id=class_id,
+          weights=[[0], [0]])
       self._test_streaming_sparse_recall_at_k(
           predictions,
           labels,
@@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
           expected=NAN,
           class_id=class_id,
           weights=[[0, 0], [0, 0]])
+      self._test_sparse_recall_at_top_k(
+          labels,
+          top_k_predictions,
+          expected=NAN,
+          class_id=class_id,
+          weights=[[0, 0], [0, 0]])
     self._test_streaming_sparse_recall_at_k(
         predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
+    self._test_sparse_recall_at_top_k(
+        labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
     self._test_streaming_sparse_recall_at_k(
         predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
+    self._test_sparse_recall_at_top_k(
+        labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
 
   def test_3d_ignore_some(self):
     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
+    top_k_predictions = [[
+        [9, 4, 6, 2, 0],
+        [5, 7, 2, 9, 6],
+    ], [
+        [5, 7, 2, 9, 6],
+        [9, 4, 6, 2, 0],
+    ]]
     labels = _binary_3d_label_to_sparse_value(
         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
         expected=2.0 / 2.0,
         class_id=2,
         weights=[[1], [0]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=2.0 / 2.0,
+        class_id=2,
+        weights=[[1], [0]])
 
     # Class 2: 2 labels, both correct.
     self._test_streaming_sparse_recall_at_k(
@@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
        expected=2.0 / 2.0,
        class_id=2,
        weights=[[0], [1]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=2.0 / 2.0,
+        class_id=2,
+        weights=[[0], [1]])
 
     # Class 7: 1 label, correct.
     self._test_streaming_sparse_recall_at_k(
@@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
        expected=1.0 / 1.0,
        class_id=7,
        weights=[[0], [1]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=1.0 / 1.0,
+        class_id=7,
+        weights=[[0], [1]])
 
     # Class 7: 1 label, incorrect.
     self._test_streaming_sparse_recall_at_k(
@@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
        expected=0.0 / 1.0,
        class_id=7,
        weights=[[1], [0]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=0.0 / 1.0,
+        class_id=7,
+        weights=[[1], [0]])
 
     # Class 7: 2 labels, 1 correct.
     self._test_streaming_sparse_recall_at_k(
@@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
        expected=1.0 / 2.0,
        class_id=7,
        weights=[[1, 0], [1, 0]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=1.0 / 2.0,
+        class_id=7,
+        weights=[[1, 0], [1, 0]])
 
     # Class 7: No labels.
     self._test_streaming_sparse_recall_at_k(
@@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
        expected=NAN,
        class_id=7,
        weights=[[0, 1], [0, 1]])
+    self._test_sparse_recall_at_top_k(
+        labels,
+        top_k_predictions,
+        expected=NAN,
+        class_id=7,
+        weights=[[0, 1], [0, 1]])
 
   def test_sparse_tensor_value(self):
     predictions = [[0.1, 0.3, 0.2, 0.4],
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 4dc8e702ca..28ed3af9d7 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -1924,7 +1924,74 @@ def recall_at_k(labels,
     labels = _maybe_expand_labels(labels, predictions)
 
     _, top_k_idx = nn.top_k(predictions, k)
-    top_k_idx = math_ops.to_int64(top_k_idx)
+    return _sparse_recall_at_top_k(
+        labels=labels,
+        predictions_idx=top_k_idx,
+        k=k,
+        class_id=class_id,
+        weights=weights,
+        metrics_collections=metrics_collections,
+        updates_collections=updates_collections,
+        name=scope)
+
+
+def _sparse_recall_at_top_k(labels,
+                            predictions_idx,
+                            k=None,
+                            class_id=None,
+                            weights=None,
+                            metrics_collections=None,
+                            updates_collections=None,
+                            name=None):
+  """Computes recall@k of top-k predictions with respect to sparse labels.
+
+  Differs from `recall_at_k` in that predictions must be in the form of top `k`
+  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
+  for more details.
+
+  Args:
+    labels: `int64` `Tensor` or `SparseTensor` with shape
+      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+      num_labels=1. N >= 1 and num_labels is the number of target classes for
+      the associated prediction. Commonly, N=1 and `labels` has shape
+      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+      should be in range [0, num_classes), where num_classes is the last
+      dimension of `predictions`. Values outside this range always count
+      towards `false_negative_at_<k>`.
+    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+      Commonly, N=1 and predictions has shape [batch size, k]. The final
+      dimension contains the top `k` predicted class indices. [D1, ... DN] must
+      match `labels`.
+    k: Integer, k for @k metric.
+    class_id: Integer class ID for which we want binary metrics. This should be
+      in range [0, num_classes), where num_classes is the last dimension of
+      `predictions`. If class_id is outside this range, the method returns NAN.
+    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+      dimensions must be either `1`, or the same as the corresponding `labels`
+      dimension).
+    metrics_collections: An optional list of collections that values should
+      be added to.
+    updates_collections: An optional list of collections that updates should
+      be added to.
+    name: Name of new update operation, and namespace for other dependent ops.
+
+  Returns:
+    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+      by the sum of `true_positives` and `false_negatives`.
+    update_op: `Operation` that increments `true_positives` and
+      `false_negatives` variables appropriately, and whose value matches
+      `recall`.
+
+  Raises:
+    ValueError: If `weights` is not `None` and its shape doesn't match
+      `predictions`, or if either `metrics_collections` or `updates_collections`
+      are not a list or tuple.
+  """
+  with ops.name_scope(name,
+                      _at_k_name('recall', k, class_id=class_id),
+                      (predictions_idx, labels, weights)) as scope:
+    top_k_idx = math_ops.to_int64(predictions_idx)
     tp, tp_update = _streaming_sparse_true_positive_at_k(
         predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
         weights=weights)
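For reference, the quantity the refactored code computes reduces to simple set arithmetic per batch entry. A plain-NumPy sketch follows (an editorial illustration of the metric's semantics, not the TensorFlow implementation); the values mirror test_three_labels_at_k5 above.

import numpy as np

def recall_at_top_k(labels, top_k_idx):
  # recall = TP / (TP + FN): TP counts label classes present in the top-k
  # index set; FN counts label classes missing from it.
  tp = sum(len(np.intersect1d(l, p)) for l, p in zip(labels, top_k_idx))
  fn = sum(len(np.setdiff1d(l, p)) for l, p in zip(labels, top_k_idx))
  # With no labels there is nothing to recall; the TF metric yields NaN.
  return float('nan') if tp + fn == 0 else float(tp) / (tp + fn)

labels = [np.array([2, 7, 8]), np.array([1, 2, 5])]
top_k = [np.array([9, 4, 6, 2, 0]), np.array([5, 7, 2, 9, 6])]
print(recall_at_top_k(labels, top_k))  # 0.5, i.e. 3 of 6 labels in the top-5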