author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-10-14 12:11:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-10-14 13:19:52 -0700
commit     33970acb0a94fbf3f3ef0f56733fd6d6a4af4a01
tree       303de2d28360b07b4bb8fc5a6c48d1895a810454
parent     3a8ec777219002edf05264730f06760d1c346d6f
Add streaming_sparse_precision_at_top_k in metric_ops.
Change: 136189470
 tensorflow/contrib/metrics/__init__.py                   |   2
 tensorflow/contrib/metrics/python/ops/metric_ops.py      | 212
 tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 185
3 files changed, 374 insertions, 25 deletions
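
For orientation before the diff itself, here is a minimal usage sketch of the new metric (an illustration only, not part of the patch). It assumes the op is exported as tf.contrib.metrics.streaming_sparse_precision_at_top_k, as the __init__.py hunk below indicates, and uses the TF 0.11-era graph/session API that the tests use. The values mirror the "three labels at k=5" test case, where 3 of the 10 top-5 predictions are correct:

    import tensorflow as tf

    # Per-example top-5 predicted class indices, and dense int64 labels
    # (the op also accepts a SparseTensor of label ids).
    top_k_predictions = tf.constant([[9, 4, 6, 2, 0],
                                     [5, 7, 2, 9, 6]], dtype=tf.int64)
    labels = tf.constant([[2, 7, 8],
                          [1, 2, 5]], dtype=tf.int64)

    precision, update_op = tf.contrib.metrics.streaming_sparse_precision_at_top_k(
        top_k_predictions=top_k_predictions, labels=labels)

    with tf.Session() as sess:
        sess.run(tf.initialize_variables(tf.local_variables()))
        sess.run(update_op)           # accumulate true/false positives
        print(sess.run(precision))    # 3 correct of 10 predicted -> 0.3
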
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 9926d98046..a0b7b1ccff 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -120,6 +120,7 @@ time.
@@streaming_sensitivity_at_specificity
@@streaming_sparse_average_precision_at_k
@@streaming_sparse_precision_at_k
+@@streaming_sparse_precision_at_top_k
@@streaming_sparse_recall_at_k
@@streaming_specificity_at_sensitivity
@@streaming_concat
@@ -172,6 +173,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_root_mean
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sensitivity_at_specificity
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_average_precision_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_precision_at_k
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_precision_at_top_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_recall_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_specificity_at_sensitivity
from tensorflow.contrib.metrics.python.ops.set_ops import set_difference
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 088fa04516..98a00adf2d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1222,8 +1222,11 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds,
return recall, update_op
-def _at_k_name(name, k, class_id=None):
- name = '%s_at_%d' % (name, k)
+def _at_k_name(name, k=None, class_id=None):
+ if k is not None:
+ name = '%s_at_%d' % (name, k)
+ else:
+ name = '%s_at_k' % (name)
if class_id is not None:
name = '%s_class%d' % (name, class_id)
return name
@@ -1388,6 +1391,79 @@ def streaming_sparse_recall_at_k(predictions,
return metric, update
+def _streaming_sparse_precision_at_k(top_k_idx,
+ labels,
+ k=None,
+ class_id=None,
+ ignore_mask=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes precision@k of the top-k indices with respect to sparse labels.
+
+ This method contains the code shared by streaming_sparse_precision_at_k and
+ streaming_sparse_precision_at_top_k. Refer to those methods for more details.
+
+ Args:
+ top_k_idx: Integer `Tensor` with shape [D1, ... DN, k] where
+ N >= 1. Commonly, N=1 and top_k_idx has shape [batch size, k].
+ The final dimension contains the indices of top-k labels. [D1, ... DN]
+ must match `labels`.
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+ target classes for the associated prediction. Commonly, N=1 and `labels`
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
+ `predictions_idx`. Values should be in range [0, num_classes), where
+ num_classes is the last dimension of `predictions`. Values outside this
+ range are ignored.
+ k: Integer, k for @k metric or `None`. Only used for default op name.
+ class_id: Integer class ID for which we want binary metrics. This should be
+ in range [0, num_classes), where num_classes is the last dimension of
+ `predictions`. If `class_id` is outside this range, the method returns
+ NAN.
+ ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
+    the first [D1, ... DN] dimensions of `predictions` and `labels`.
+  weights: An optional `Tensor` whose shape is broadcastable to the first
+ [D1, ... DN] dimensions of `predictions` and `labels`.
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of the metric and of the enclosing scope.
+
+ Returns:
+ precision: Scalar `float64` `Tensor` with the value of `true_positives`
+ divided by the sum of `true_positives` and `false_positives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_positives` variables appropriately, and whose value matches
+ `precision`.
+
+ Raises:
+ ValueError: If `ignore_mask` is not `None` and its shape doesn't match
+ `predictions`, or if `weights` is not `None` and its shape doesn't match
+ `predictions`, or if either `metrics_collections` or `updates_collections`
+ are not a list or tuple.
+ """
+ top_k_idx = math_ops.to_int64(top_k_idx)
+ weights = _mask_weights(ignore_mask, weights)
+ tp, tp_update = _streaming_sparse_true_positive_at_k(
+ predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ weights=weights)
+ fp, fp_update = _streaming_sparse_false_positive_at_k(
+ predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ weights=weights)
+
+ metric = math_ops.div(tp, math_ops.add(tp, fp), name=name)
+ update = math_ops.div(
+ tp_update, math_ops.add(tp_update, fp_update), name='update')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update)
+ return metric, update
+
+
# TODO(ptucker): Validate range of values in labels?
@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
def streaming_sparse_precision_at_k(predictions,
@@ -1443,7 +1519,8 @@ def streaming_sparse_precision_at_k(predictions,
k: Integer, k for @k metric.
class_id: Integer class ID for which we want binary metrics. This should be
in range [0, num_classes], where num_classes is the last dimension of
- `predictions`. If class_id is outside this range, the method returns NAN.
+ `predictions`. If `class_id` is outside this range, the method returns
+ NAN.
ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
the first [D1, ... DN] dimensions of `predictions` and `labels`.
weights: An optional `Tensor` whose shape is broadcastable to the first
@@ -1468,25 +1545,116 @@ def streaming_sparse_precision_at_k(predictions,
are not a list or tuple.
"""
default_name = _at_k_name('precision', k, class_id=class_id)
- with ops.name_scope(name, default_name, (predictions, labels)) as scope:
+ with ops.name_scope(name, default_name,
+ (predictions, labels, ignore_mask, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
- top_k_idx = math_ops.to_int64(top_k_idx)
- weights = _mask_weights(ignore_mask, weights)
- tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
- fp, fp_update = _streaming_sparse_false_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
- weights=weights)
+ return _streaming_sparse_precision_at_k(
+ top_k_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
+ ignore_mask=ignore_mask,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=scope)
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
- update = math_ops.div(
- tp_update, math_ops.add(tp_update, fp_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
- if updates_collections:
- ops.add_to_collections(updates_collections, update)
- return metric, update
+
+# TODO(ptucker): Validate range of values in labels?
+@deprecated_args(IGNORE_MASK_DATE, IGNORE_MASK_INSTRUCTIONS, 'ignore_mask')
+def streaming_sparse_precision_at_top_k(top_k_predictions,
+ labels,
+ class_id=None,
+ ignore_mask=None,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes precision@k of top-k predictions with respect to sparse labels.
+
+ If `class_id` is specified, we calculate precision by considering only the
+ entries in the batch for which `class_id` is in the top-k highest
+ `predictions`, and computing the fraction of them for which `class_id` is
+ indeed a correct label.
+ If `class_id` is not specified, we'll calculate precision as how often on
+ average a class among the top-k classes with the highest predicted values
+ of a batch entry is correct and can be found in the label for that entry.
+
+ `streaming_sparse_precision_at_top_k` creates two local variables,
+ `true_positive_at_k` and `false_positive_at_k`, that are used to compute
+ the precision@k frequency. This frequency is ultimately returned as
+ `precision_at_k`: an idempotent operation that simply divides
+ `true_positive_at_k` by total (`true_positive_at_k` + `false_positive_at_k`).
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `precision_at_k`. Internally, set operations applied to `top_k_predictions`
+ and `labels` calculate the true positives and false positives weighted by
+ `weights`. Then `update_op` increments `true_positive_at_k` and
+ `false_positive_at_k` using these values.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+ Alternatively, if `ignore_mask` is not `None`, then mask values where
+ `ignore_mask` is `True`.
+
+ Args:
+ top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
+ N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
+ The final dimension contains the indices of top-k labels. [D1, ... DN]
+ must match `labels`.
+ labels: `int64` `Tensor` or `SparseTensor` with shape
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+ target classes for the associated prediction. Commonly, N=1 and `labels`
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
+ `top_k_predictions`. Values should be in range [0, num_classes), where
+ num_classes is the last dimension of `predictions`. Values outside this
+ range are ignored.
+ class_id: Integer class ID for which we want binary metrics. This should be
+ in range [0, num_classes), where num_classes is the last dimension of
+ `predictions`. If `class_id` is outside this range, the method returns
+ NAN.
+ ignore_mask: An optional, `bool` `Tensor` whose shape is broadcastable to
+    the first [D1, ... DN] dimensions of `predictions` and `labels`.
+  weights: An optional `Tensor` whose shape is broadcastable to the first
+ [D1, ... DN] dimensions of `predictions` and `labels`.
+ metrics_collections: An optional list of collections that values should
+ be added to.
+ updates_collections: An optional list of collections that updates should
+ be added to.
+ name: Name of new update operation, and namespace for other dependent ops.
+
+ Returns:
+ precision: Scalar `float64` `Tensor` with the value of `true_positives`
+ divided by the sum of `true_positives` and `false_positives`.
+ update_op: `Operation` that increments `true_positives` and
+ `false_positives` variables appropriately, and whose value matches
+ `precision`.
+
+ Raises:
+ ValueError: If `ignore_mask` is not `None` and its shape doesn't match
+ `predictions`, or if `weights` is not `None` and its shape doesn't match
+ `predictions`, or if either `metrics_collections` or `updates_collections`
+ are not a list or tuple.
+ ValueError: If `top_k_predictions` has rank < 2.
+ """
+ default_name = _at_k_name('precision', class_id=class_id)
+ with ops.name_scope(
+ name, default_name,
+ (top_k_predictions, labels, ignore_mask, weights)) as scope:
+ rank = array_ops.rank(top_k_predictions)
+ check_rank_op = control_flow_ops.Assert(
+ math_ops.greater_equal(rank, 2),
+ ['top_k_predictions must have rank 2 or higher, e.g. [batch_size, k].'])
+ with ops.control_dependencies([check_rank_op]):
+ return _streaming_sparse_precision_at_k(
+ top_k_idx=top_k_predictions,
+ labels=labels,
+ class_id=class_id,
+ ignore_mask=ignore_mask,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=scope)
def num_relevant(labels, k):
@@ -1869,7 +2037,7 @@ def _sparse_true_positive_at_k(predictions_idx,
def _streaming_sparse_true_positive_at_k(predictions_idx,
labels,
- k,
+ k=None,
class_id=None,
weights=None,
name=None):
@@ -1957,7 +2125,7 @@ def _sparse_false_positive_at_k(predictions_idx,
def _streaming_sparse_false_positive_at_k(predictions_idx,
labels,
- k,
+ k=None,
class_id=None,
weights=None,
name=None):
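
Before the test-file diff, a plain-Python sketch of the counting done by the shared helper above may help (an illustration only, not code from the patch). It reproduces the tp / (tp + fp) division for the whole batch and for a single class_id, using the same values as the "three labels at k=5" tests below; it omits the SparseTensor labels, weights, ignore_mask, and out-of-range filtering that the real op handles:

    def precision_at_top_k(top_k_idx, label_sets, class_id=None):
        # Count true and false positives over all examples, optionally
        # restricted to one class, then divide as the streaming op does.
        tp = fp = 0
        for idx, labels in zip(top_k_idx, label_sets):
            for i in idx:
                if class_id is not None and i != class_id:
                    continue  # only predictions of `class_id` count
                if i in labels:
                    tp += 1
                else:
                    fp += 1
        return float('nan') if tp + fp == 0 else tp / float(tp + fp)

    top_k = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
    label_sets = [{2, 7, 8}, {1, 2, 5}]
    print(precision_at_top_k(top_k, label_sets))              # 3/10 = 0.3
    print(precision_at_top_k(top_k, label_sets, class_id=2))  # 2/2  = 1.0
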
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 6465f69b25..1bb4f53cec 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1752,6 +1752,36 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
+ def _test_streaming_sparse_precision_at_top_k(self,
+ top_k_predictions,
+ labels,
+ expected,
+ class_id=None,
+ ignore_mask=None,
+ weights=None):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ if ignore_mask is not None:
+ ignore_mask = tf.constant(ignore_mask, tf.bool)
+ if weights is not None:
+ weights = tf.constant(weights, tf.float32)
+ metric, update = metrics.streaming_sparse_precision_at_top_k(
+ top_k_predictions=tf.constant(top_k_predictions, tf.int32),
+ labels=labels, class_id=class_id, ignore_mask=ignore_mask,
+ weights=weights)
+
+ # Fails without initialized vars.
+ self.assertRaises(tf.OpError, metric.eval)
+ self.assertRaises(tf.OpError, update.eval)
+ tf.initialize_variables(tf.local_variables()).run()
+
+ # Run per-step op and assert expected values.
+ if math.isnan(expected):
+ self.assertTrue(math.isnan(update.eval()))
+ self.assertTrue(math.isnan(metric.eval()))
+ else:
+ self.assertEqual(expected, update.eval())
+ self.assertEqual(expected, metric.eval())
+
def _test_sparse_average_precision_at_k(self,
predictions,
labels,
@@ -1789,14 +1819,31 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
self.assertAlmostEqual(expected, update.eval())
self.assertAlmostEqual(expected, metric.eval())
+ def test_top_k_rank_invalid(self):
+ with self.test_session():
+ # top_k_predictions has rank < 2.
+ top_k_predictions = [9, 4, 6, 2, 0]
+ sp_labels = tf.SparseTensorValue(
+ indices=np.array([[0,], [1,], [2,]], np.int64),
+ values=np.array([2, 7, 8], np.int64),
+ shape=np.array([10,], np.int64))
+
+ with self.assertRaises(ValueError):
+ precision, _ = metrics.streaming_sparse_precision_at_top_k(
+ top_k_predictions=tf.constant(top_k_predictions, tf.int64),
+ labels=sp_labels)
+ tf.initialize_variables(tf.local_variables()).run()
+ precision.eval()
+
def test_average_precision(self):
# Example 1.
# Matches example here:
# fastml.com/what-you-wanted-to-know-about-mean-average-precision
labels_ex1 = (0, 1, 2, 3, 4)
labels = np.array([labels_ex1], dtype=np.int64)
- predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3) # [5, 3, 6, 1, 2]
+ predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
predictions = (predictions_ex1,)
+ predictions_top_k_ex1 = (5, 3, 6, 0, 1, 2)
precision_ex1 = (
0.0 / 1,
1.0 / 2,
@@ -1813,6 +1860,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k = i + 1
self._test_streaming_sparse_precision_at_k(
predictions, labels, k, expected=precision_ex1[i])
+ self._test_streaming_sparse_precision_at_top_k(
+ (predictions_top_k_ex1[:k],), labels, expected=precision_ex1[i])
self._test_sparse_average_precision_at_k(
predictions, labels, k, expected=[avg_precision_ex1[i]])
self._test_streaming_sparse_average_precision_at_k(
@@ -1821,8 +1870,9 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Example 2.
labels_ex2 = (0, 2, 4, 5, 6)
labels = np.array([labels_ex2], dtype=np.int64)
- predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2) # [1, 3, 0, 6, 5]
+ predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2)
predictions = (predictions_ex2,)
+ predictions_top_k_ex2 = (1, 3, 0, 6, 5)
precision_ex2 = (
0.0 / 1,
0.0 / 2,
@@ -1839,6 +1889,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k = i + 1
self._test_streaming_sparse_precision_at_k(
predictions, labels, k, expected=precision_ex2[i])
+ self._test_streaming_sparse_precision_at_top_k(
+ (predictions_top_k_ex2[:k],), labels, expected=precision_ex2[i])
self._test_sparse_average_precision_at_k(
predictions, labels, k, expected=[avg_precision_ex2[i]])
self._test_streaming_sparse_average_precision_at_k(
@@ -1860,6 +1912,9 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k = i + 1
self._test_streaming_sparse_precision_at_k(
predictions, labels, k, expected=streaming_precision[i])
+ predictions_top_k = (predictions_top_k_ex1[:k], predictions_top_k_ex2[:k])
+ self._test_streaming_sparse_precision_at_top_k(
+ predictions_top_k, labels, expected=streaming_precision[i])
self._test_sparse_average_precision_at_k(
predictions, labels, k, expected=average_precision[i])
self._test_streaming_sparse_average_precision_at_k(
@@ -1881,8 +1936,9 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
"""Tests that labels outside the [0, n_classes) range are ignored."""
labels_ex1 = (-1, 0, 1, 2, 3, 4, 7)
labels = np.array([labels_ex1], dtype=np.int64)
- predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3) # [5, 3, 6, 1, 2]
+ predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
predictions = (predictions_ex1,)
+ predictions_top_k_ex1 = (5, 3, 6, 0, 1, 2)
precision_ex1 = (
0.0 / 1,
1.0 / 2,
@@ -1899,6 +1955,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
k = i + 1
self._test_streaming_sparse_precision_at_k(
predictions, labels, k, expected=precision_ex1[i])
+ self._test_streaming_sparse_precision_at_top_k(
+ (predictions_top_k_ex1[:k],), labels, expected=precision_ex1[i])
self._test_sparse_average_precision_at_k(
predictions, labels, k, expected=[avg_precision_ex1[i]])
self._test_streaming_sparse_average_precision_at_k(
@@ -1906,6 +1964,7 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -1915,9 +1974,12 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
for class_id in (-1, 0, 1, 2, 4):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=1, expected=NAN, class_id=class_id)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=class_id)
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@@ -1926,16 +1988,24 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=1, expected=1.0 / 2, class_id=3)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 2, class_id=3)
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=1, expected=1.0 / 2)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 2)
def test_three_labels_at_k5_no_predictions(self):
predictions = [
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]
]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value([
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -1947,12 +2017,18 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
for class_id in (-1, 1, 3, 8, 10):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=class_id)
def test_three_labels_at_k5_no_labels(self):
predictions = [
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]
]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value([
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -1964,12 +2040,18 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
for class_id in (0, 4, 6, 9):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=0.0, class_id=class_id)
def test_three_labels_at_k5(self):
predictions = [
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]
]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sparse_labels = _binary_2d_label_to_sparse_value([
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -1981,18 +2063,26 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2,
class_id=2)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, 1 correct prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, 1 incorrect prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=0.0 / 1, class_id=7)
# All classes: 10 predictions, 3 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=3.0 / 10)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=3.0 / 10)
def test_three_labels_at_k5_some_out_of_range(self):
"""Tests that labels outside the [0, n_classes) range are ignored."""
@@ -2000,6 +2090,10 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]
]
+ top_k_predictions = [
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ]
sp_labels = tf.SparseTensorValue(
indices=[[0, 0], [0, 1], [0, 2], [0, 3],
[1, 0], [1, 1], [1, 2], [1, 3]],
@@ -2011,18 +2105,26 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Class 2: 2 labels, 2 correct predictions.
self._test_streaming_sparse_precision_at_k(
predictions, sp_labels, k=5, expected=2.0 / 2, class_id=2)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, sp_labels, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, 1 correct prediction.
self._test_streaming_sparse_precision_at_k(
predictions, sp_labels, k=5, expected=1.0 / 1, class_id=5)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, sp_labels, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, 1 incorrect prediction.
self._test_streaming_sparse_precision_at_k(
predictions, sp_labels, k=5, expected=0.0 / 1, class_id=7)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, sp_labels, expected=0.0 / 1, class_id=7)
# All classes: 10 predictions, 3 correct.
self._test_streaming_sparse_precision_at_k(
predictions, sp_labels, k=5, expected=3.0 / 10)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, sp_labels, expected=3.0 / 10)
def test_3d_nan(self):
predictions = [[
@@ -2032,6 +2134,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value([[
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -2044,6 +2153,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
for class_id in (-1, 1, 3, 8, 10):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=class_id)
def test_3d_no_labels(self):
predictions = [[
@@ -2053,6 +2164,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value([[
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -2065,6 +2183,8 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
for class_id in (0, 4, 6, 9):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=0.0, class_id=class_id)
def test_3d(self):
predictions = [[
@@ -2074,6 +2194,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value([[
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -2085,18 +2212,26 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
# Class 2: 4 predictions, all correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=4.0 / 4, class_id=2)
# Class 5: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=2.0 / 2, class_id=5)
# Class 7: 2 predictions, 1 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 2, class_id=7)
# All classes: 20 predictions, 7 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=7.0 / 20)
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=7.0 / 20)
def test_3d_ignore_all(self):
predictions = [[
@@ -2106,6 +2241,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value([[
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -2118,14 +2260,26 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id,
weights=[[0], [0]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=class_id,
+ weights=[[0], [0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id,
weights=[[0, 0], [0, 0]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=class_id,
+ weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, ignore_mask=[[False], [True]],
weights=[[0], [1]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN,
+ ignore_mask=[[False], [True]], weights=[[0], [1]])
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN,
+ weights=[[0, 0], [0, 0]])
def test_3d_ignore_some(self):
predictions = [[
@@ -2135,6 +2289,13 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
]]
+ top_k_predictions = [[
+ [9, 4, 6, 2, 0],
+ [5, 7, 2, 9, 6],
+ ], [
+ [5, 7, 2, 9, 6],
+ [9, 4, 6, 2, 0],
+ ]]
labels = _binary_3d_label_to_sparse_value([[
[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]
@@ -2147,31 +2308,49 @@ class StreamingSparsePrecisionTest(tf.test.TestCase):
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[1], [0]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
+ ignore_mask=[[False], [False]], weights=[[1], [0]])
# Class 2: 2 predictions, both correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
ignore_mask=[[False], [False]], weights=[[0], [1]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=2.0 / 2.0, class_id=2,
+ ignore_mask=[[False], [False]], weights=[[0], [1]])
# Class 7: 1 incorrect prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
ignore_mask=[[False], [True]], weights=[[1], [1]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=0.0 / 1.0, class_id=7,
+ ignore_mask=[[False], [True]], weights=[[1], [1]])
# Class 7: 1 correct prediction.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
ignore_mask=[[True], [False]], weights=[[1], [1]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 1.0, class_id=7,
+ ignore_mask=[[True], [False]], weights=[[1], [1]])
# Class 7: no predictions.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=NAN, class_id=7,
weights=[[1, 0], [0, 1]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=NAN, class_id=7,
+ weights=[[1, 0], [0, 1]])
# Class 7: 2 predictions, 1 correct.
self._test_streaming_sparse_precision_at_k(
predictions, labels, k=5, expected=1.0 / 2.0, class_id=7,
weights=[[0, 1], [1, 0]])
+ self._test_streaming_sparse_precision_at_top_k(
+ top_k_predictions, labels, expected=1.0 / 2.0, class_id=7,
+ weights=[[0, 1], [1, 0]])
def test_sparse_tensor_value(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]