author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-14 01:54:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-14 01:59:38 -0700
commit     a3667c483ebf839653d71cf42a0a71196a513dc9 (patch)
tree       09cf1f454118d4a1af164b8486aae6faa901cd5b /tensorflow/contrib/metrics
parent     860f8c50753bcbfca8243c585033b3d44c4b7c7f (diff)
Add streaming_false_{negative,positive}_rate and streaming_false_{negative,positive}_rate_at_thresholds.
PiperOrigin-RevId: 172191462
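
All four new ops follow the usual streaming-metric contract: they return an idempotent value tensor plus an `update_op` that accumulates confusion-matrix counts across batches. A minimal usage sketch (assuming a TF 1.x graph-mode session; the inputs mirror the `testSomeCorrect` case added below):

    import tensorflow as tf
    from tensorflow.contrib import metrics as contrib_metrics

    predictions = tf.constant([1, 0, 1, 0], shape=(1, 4))
    labels = tf.constant([0, 1, 1, 0], shape=(1, 4))

    # After casting to bool: fp = 1 (index 0), tn = 1 (index 3),
    # so fpr = fp / (fp + tn) = 0.5.
    fpr, update_op = contrib_metrics.streaming_false_positive_rate(
        predictions, labels)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(fpr))  # 0.5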
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--  tensorflow/contrib/metrics/__init__.py                      8
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py       347
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  720
3 files changed, 1075 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index a9bce65e55..2c48882d0e 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -22,6 +22,10 @@ See the @{$python/contrib.metrics} guide.
@@streaming_recall_at_thresholds
@@streaming_precision
@@streaming_precision_at_thresholds
+@@streaming_false_positive_rate
+@@streaming_false_positive_rate_at_thresholds
+@@streaming_false_negative_rate
+@@streaming_false_negative_rate_at_thresholds
@@streaming_auc
@@streaming_curve_points
@@streaming_recall_at_k
@@ -80,8 +84,12 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean
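
Since `__init__.py` re-exports the implementation symbols directly, the new ops resolve to the same objects through either path; a quick check (illustrative sketch, TF 1.x):

    # Both import paths should yield the identical function object
    # after this change.
    from tensorflow.contrib.metrics import streaming_false_positive_rate
    from tensorflow.contrib.metrics.python.ops.metric_ops import (
        streaming_false_positive_rate as impl)
    assert streaming_false_positive_rate is impl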
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 76986d0156..85c8e9038a 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -565,6 +565,213 @@ def streaming_recall(predictions, labels, weights=None,
updates_collections=updates_collections, name=name)
+def _true_negatives(labels, predictions, weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Sum the weights of true negatives.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ labels: The ground truth values, a `Tensor` whose dimensions must match
+ `predictions`. Will be cast to `bool`.
+ predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+ be cast to `bool`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ metrics_collections: An optional list of collections that the metric
+ value variable should be added to.
+ updates_collections: An optional list of collections that the metric update
+ ops should be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ value_tensor: A `Tensor` representing the current value of the metric.
+ update_op: An operation that accumulates the error from a batch of data.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'true_negatives', (predictions, labels, weights)):
+
+ predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
+ weights=weights)
+ is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
+ math_ops.equal(predictions, False))
+ return _count_condition(is_true_negative, weights, metrics_collections,
+ updates_collections)
+
+
+def streaming_false_positive_rate(predictions, labels, weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes the false positive rate of predictions with respect to labels.
+
+ The `false_positive_rate` function creates two local variables,
+ `false_positives` and `true_negatives`, that are used to compute the
+ false positive rate. This value is ultimately returned as
+ `false_positive_rate`, an idempotent operation that simply divides
+ `false_positives` by the sum of `false_positives` and `true_negatives`.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `false_positive_rate`. `update_op` weights each prediction by the
+ corresponding value in `weights`.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+ be cast to `bool`.
+ labels: The ground truth values, a `Tensor` whose dimensions must match
+ `predictions`. Will be cast to `bool`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ metrics_collections: An optional list of collections that
+ `false_positive_rate` should be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ false_positive_rate: Scalar float `Tensor` with the value of
+ `false_positives` divided by the sum of `false_positives` and
+ `true_negatives`.
+ update_op: `Operation` that increments `false_positives` and
+ `true_negatives` variables appropriately and whose value matches
+ `false_positive_rate`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'false_positive_rate', (predictions, labels, weights)):
+ predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
+ weights=weights)
+
+ false_p, false_positives_update_op = metrics.false_positives(
+ labels, predictions, weights, metrics_collections=None,
+ updates_collections=None, name=None)
+ true_n, true_negatives_update_op = _true_negatives(
+ labels, predictions, weights, metrics_collections=None,
+ updates_collections=None, name=None)
+
+ def compute_fpr(fp, tn, name):
+ return array_ops.where(
+ math_ops.greater(fp + tn, 0),
+ math_ops.div(fp, fp + tn),
+ 0,
+ name)
+
+ fpr = compute_fpr(false_p, true_n, 'value')
+ update_op = compute_fpr(
+ false_positives_update_op, true_negatives_update_op, 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, fpr)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return fpr, update_op
+
+
+def streaming_false_negative_rate(predictions, labels, weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes the false negative rate of predictions with respect to labels.
+
+ The `false_negative_rate` function creates two local variables,
+ `false_negatives` and `true_positives`, that are used to compute the
+ false negative rate. This value is ultimately returned as
+ `false_negative_rate`, an idempotent operation that simply divides
+ `false_negatives` by the sum of `false_negatives` and `true_positives`.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `false_negative_rate`. `update_op` weights each prediction by the
+ corresponding value in `weights`.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+ be cast to `bool`.
+ labels: The ground truth values, a `Tensor` whose dimensions must match
+ `predictions`. Will be cast to `bool`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ metrics_collections: An optional list of collections that
+ `false_negative_rate` should be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ false_negative_rate: Scalar float `Tensor` with the value of
+ `false_negatives` divided by the sum of `false_negatives` and
+ `true_positives`.
+ update_op: `Operation` that increments `false_negatives` and
+ `true_positives` variables appropriately and whose value matches
+ `false_negative_rate`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'false_negative_rate', (predictions, labels, weights)):
+ predictions, labels, weights = _remove_squeezable_dimensions(
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
+ weights=weights)
+
+ false_n, false_negatives_update_op = metrics.false_negatives(
+ labels, predictions, weights, metrics_collections=None,
+ updates_collections=None, name=None)
+ true_p, true_positives_update_op = metrics.true_positives(
+ labels, predictions, weights, metrics_collections=None,
+ updates_collections=None, name=None)
+
+ def compute_fnr(fn, tp, name):
+ return array_ops.where(
+ math_ops.greater(fn + tp, 0),
+ math_ops.div(fn, fn + tp),
+ 0,
+ name)
+
+ fnr = compute_fnr(false_n, true_p, 'value')
+ update_op = compute_fnr(
+ false_negatives_update_op, true_positives_update_op, 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, fnr)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return fnr, update_op
+
+
def _streaming_confusion_matrix_at_thresholds(
predictions, labels, thresholds, weights=None, includes=None):
"""Computes true_positives, false_negatives, true_negatives, false_positives.
@@ -1114,6 +1321,142 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds,
updates_collections=updates_collections, name=name)
+def streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds, weights=None, metrics_collections=None,
+ updates_collections=None, name=None):
+ """Computes various fpr values for different `thresholds` on `predictions`.
+
+ The `streaming_false_positive_rate_at_thresholds` function creates two
+ local variables, `false_positives` and `true_negatives`, for various
+ threshold values. `false_positive_rate[i]` is defined as the total weight
+ of values in `predictions` above `thresholds[i]` whose corresponding entry in
+ `labels` is `False`, divided by the total weight of `False` values in `labels`
+ (`false_positives[i] / (false_positives[i] + true_negatives[i])`).
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `false_positive_rate`.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ labels: A `bool` `Tensor` whose shape matches `predictions`.
+ thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+ weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
+ must be broadcastable to `labels` (i.e., all dimensions must be either
+ `1`, or the same as the corresponding `labels` dimension).
+ metrics_collections: An optional list of collections that
+ `false_positive_rate` should be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`.
+ update_op: An operation that increments the `false_positives` and
+ `true_negatives` variables that are used in the computation of
+ `false_positive_rate`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'false_positive_rate_at_thresholds',
+ (predictions, labels, weights)):
+ values, update_ops = _streaming_confusion_matrix_at_thresholds(
+ predictions, labels, thresholds, weights, includes=('fp', 'tn'))
+
+ # Avoid division by zero.
+ epsilon = 1e-7
+ def compute_fpr(fp, tn, name):
+ return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name)
+
+ fpr = compute_fpr(values['fp'], values['tn'], 'value')
+ update_op = compute_fpr(
+ update_ops['fp'], update_ops['tn'], 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, fpr)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return fpr, update_op
+
+
+def streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds, weights=None, metrics_collections=None,
+ updates_collections=None, name=None):
+ """Computes various fnr values for different `thresholds` on `predictions`.
+
+ The `streaming_false_negative_rate_at_thresholds` function creates two
+ local variables, `false_negatives` and `true_positives`, for various
+ threshold values. `false_negative_rate[i]` is defined as the total weight
+ of values in `predictions` at or below `thresholds[i]` whose corresponding
+ entry in `labels` is `True`, divided by the total weight of `True` values in
+ `labels`
+ (`false_negatives[i] / (false_negatives[i] + true_positives[i])`).
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `false_negative_rate`.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ labels: A `bool` `Tensor` whose shape matches `predictions`.
+ thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+ weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and
+ must be broadcastable to `labels` (i.e., all dimensions must be either
+ `1`, or the same as the corresponding `labels` dimension).
+ metrics_collections: An optional list of collections that
+ `false_negative_rate` should be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`.
+ update_op: An operation that increments the `false_negatives` and
+ `true_positives` variables that are used in the computation of
+ `false_negative_rate`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'false_negative_rate_at_thresholds',
+ (predictions, labels, weights)):
+ values, update_ops = _streaming_confusion_matrix_at_thresholds(
+ predictions, labels, thresholds, weights, includes=('fn', 'tp'))
+
+ # Avoid division by zero.
+ epsilon = 1e-7
+ def compute_fnr(fn, tp, name):
+ return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name)
+
+ fnr = compute_fnr(values['fn'], values['tp'], 'value')
+ update_op = compute_fnr(
+ update_ops['fn'], update_ops['tp'], 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, fnr)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return fnr, update_op
+
+
def _at_k_name(name, k=None, class_id=None):
if k is not None:
name = '%s_at_%d' % (name, k)
@@ -2479,8 +2822,12 @@ __all__ = [
'streaming_accuracy',
'streaming_auc',
'streaming_curve_points',
+ 'streaming_false_negative_rate',
+ 'streaming_false_negative_rate_at_thresholds',
'streaming_false_negatives',
'streaming_false_negatives_at_thresholds',
+ 'streaming_false_positive_rate',
+ 'streaming_false_positive_rate_at_thresholds',
'streaming_false_positives',
'streaming_false_positives_at_thresholds',
'streaming_mean',
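
Note the two division-by-zero strategies above: the un-thresholded ops return exactly 0 when the denominator is 0 (via `array_ops.where`), while the thresholded variants add a small `epsilon` to the denominator. A hand computation of the thresholded FPR in plain NumPy (illustrative sketch only; inputs mirror the `testExtremeThresholds` case in the tests below):

    import numpy as np

    predictions = np.array([1.0, 0.0, 1.0, 0.0])
    labels = np.array([False, True, True, True])
    epsilon = 1e-7

    for t in [-1.0, 2.0]:  # below / above every prediction value
        predicted_pos = predictions > t
        fp = np.sum(predicted_pos & ~labels)   # predicted True, label False
        tn = np.sum(~predicted_pos & ~labels)  # predicted False, label False
        print(t, fp / (epsilon + fp + tn))     # -1.0 -> ~1.0, 2.0 -> ~0.0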
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 9b959b43a9..cc0ad155fa 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1355,6 +1355,262 @@ class StreamingRecallTest(test.TestCase):
self.assertEqual(0, recall.eval())
+class StreamingFPRTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.streaming_false_positive_rate(
+ predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
+ _assert_local_variables(self, (
+ 'false_positive_rate/false_positives/count:0',
+ 'false_positive_rate/true_negatives/count:0'))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ mean, _ = metrics.streaming_false_positive_rate(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [mean])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_false_positive_rate(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(update_op)
+
+ # Then verify idempotency.
+ initial_fpr = fpr.eval()
+ for _ in range(10):
+ self.assertEqual(initial_fpr, fpr.eval())
+
+ def testAllCorrect(self):
+ np_inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(np_inputs)
+ labels = constant_op.constant(np_inputs)
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, fpr.eval())
+
+ def testSomeCorrect(self):
+ predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(0.5, update_op.eval())
+ self.assertAlmostEqual(0.5, fpr.eval())
+
+ def testWeighted1d(self):
+ predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
+ labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ weights = constant_op.constant([[2], [5]])
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels, weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ weighted_fp = 2.0 + 5.0
+ weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
+ expected_fpr = weighted_fp / weighted_f
+ self.assertAlmostEqual(expected_fpr, update_op.eval())
+ self.assertAlmostEqual(expected_fpr, fpr.eval())
+
+ def testWeighted2d(self):
+ predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
+ labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels, weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ weighted_fp = 1.0 + 3.0
+ weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
+ expected_fpr = weighted_fp / weighted_f
+ self.assertAlmostEqual(expected_fpr, update_op.eval())
+ self.assertAlmostEqual(expected_fpr, fpr.eval())
+
+ def testAllIncorrect(self):
+ np_inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(np_inputs)
+ labels = constant_op.constant(1 - np_inputs)
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(1, fpr.eval())
+
+ def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self):
+ predictions = array_ops.ones((1, 4))
+ labels = array_ops.ones((1, 4))
+ fpr, update_op = metrics.streaming_false_positive_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, fpr.eval())
+
+
+class StreamingFNRTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.streaming_false_negative_rate(
+ predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
+ _assert_local_variables(self, (
+ 'false_negative_rate/false_negatives/count:0',
+ 'false_negative_rate/true_positives/count:0'))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ mean, _ = metrics.streaming_false_negative_rate(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [mean])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_false_negative_rate(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(update_op)
+
+ # Then verify idempotency.
+ initial_fnr = fnr.eval()
+ for _ in range(10):
+ self.assertEqual(initial_fnr, fnr.eval())
+
+ def testAllCorrect(self):
+ np_inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(np_inputs)
+ labels = constant_op.constant(np_inputs)
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, fnr.eval())
+
+ def testSomeCorrect(self):
+ predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(0.5, update_op.eval())
+ self.assertAlmostEqual(0.5, fnr.eval())
+
+ def testWeighted1d(self):
+ predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
+ labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ weights = constant_op.constant([[2], [5]])
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels, weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ weighted_fn = 2.0 + 5.0
+ weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
+ expected_fnr = weighted_fn / weighted_t
+ self.assertAlmostEqual(expected_fnr, update_op.eval())
+ self.assertAlmostEqual(expected_fnr, fnr.eval())
+
+ def testWeighted2d(self):
+ predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
+ labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels, weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ weighted_fn = 2.0 + 4.0
+ weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
+ expected_fnr = weighted_fn / weighted_t
+ self.assertAlmostEqual(expected_fnr, update_op.eval())
+ self.assertAlmostEqual(expected_fnr, fnr.eval())
+
+ def testAllIncorrect(self):
+ np_inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(np_inputs)
+ labels = constant_op.constant(1 - np_inputs)
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(1, fnr.eval())
+
+ def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self):
+ predictions = array_ops.zeros((1, 4))
+ labels = array_ops.zeros((1, 4))
+ fnr, update_op = metrics.streaming_false_negative_rate(
+ predictions, labels)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+ self.assertEqual(0, fnr.eval())
+
+
class StreamingCurvePointsTest(test.TestCase):
def setUp(self):
@@ -2268,6 +2524,470 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(expected_rec, rec.eval(), 2)
+class StreamingFPRThresholdsTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.streaming_false_positive_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0])
+ _assert_local_variables(self, (
+ 'false_positive_rate_at_thresholds/false_positives:0',
+ 'false_positive_rate_at_thresholds/true_negatives:0',))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ fpr, _ = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0],
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [fpr])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0],
+ updates_collections=[my_collection_name])
+ self.assertListEqual(
+ ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ thresholds = [0, 0.5, 1.0]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(fpr_op)
+
+ # Then verify idempotency.
+ initial_fpr = fpr.eval()
+ for _ in range(10):
+ self.assertAllClose(initial_fpr, fpr.eval())
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(inputs)
+ thresholds = [0.5]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertEqual(0, fpr.eval())
+
+ def testSomeCorrect(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ thresholds = [0.5]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(0.5, fpr.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
+ thresholds = [0.5]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(1, fpr.eval())
+
+ def testWeights1d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
+ thresholds = [0.5, 1.1]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds, weights=weights)
+
+ fpr_low = fpr[0]
+ fpr_high = fpr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(0.0, fpr_low.eval(), places=5)
+ self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
+
+ def testWeights2d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
+ thresholds = [0.5, 1.1]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds, weights=weights)
+
+ fpr_low = fpr[0]
+ fpr_high = fpr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(0.0, fpr_low.eval(), places=5)
+ self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
+
+ def testExtremeThresholds(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
+ thresholds = [-1.0, 2.0]  # below/above every prediction value
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ fpr_low = fpr[0]
+ fpr_high = fpr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(1.0, fpr_low.eval(), places=5)
+ self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
+
+ def testZeroLabelsPredictions(self):
+ with self.test_session() as sess:
+ predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
+ labels = array_ops.zeros([4])
+ thresholds = [0.5]
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fpr_op)
+
+ self.assertAlmostEqual(0, fpr.eval(), 6)
+
+ def testWithMultipleUpdates(self):
+ num_samples = 1000
+ batch_size = 10
+ num_batches = int(num_samples / batch_size)
+
+ # Create the labels and data.
+ labels = np.random.randint(0, 2, size=(num_samples, 1))
+ noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+ predictions = 0.4 + 0.2 * labels + noise
+ predictions[predictions > 1] = 1
+ predictions[predictions < 0] = 0
+ thresholds = [0.3]
+
+ fp = 0
+ tn = 0
+ for i in range(num_samples):
+ if predictions[i] > thresholds[0]:
+ if labels[i] == 0:
+ fp += 1
+ else:
+ if labels[i] == 0:
+ tn += 1
+ epsilon = 1e-7
+ expected_fpr = fp / (epsilon + fp + tn)
+
+ labels = labels.astype(np.float32)
+ predictions = predictions.astype(np.float32)
+
+ with self.test_session() as sess:
+ # Reshape the data so it's easy to queue up:
+ predictions_batches = predictions.reshape((batch_size, num_batches))
+ labels_batches = labels.reshape((batch_size, num_batches))
+
+ # Enqueue the data:
+ predictions_queue = data_flow_ops.FIFOQueue(
+ num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
+ labels_queue = data_flow_ops.FIFOQueue(
+ num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
+
+ for i in range(int(num_batches)):
+ tf_prediction = constant_op.constant(predictions_batches[:, i])
+ tf_label = constant_op.constant(labels_batches[:, i])
+ sess.run([
+ predictions_queue.enqueue(tf_prediction),
+ labels_queue.enqueue(tf_label)
+ ])
+
+ tf_predictions = predictions_queue.dequeue()
+ tf_labels = labels_queue.dequeue()
+
+ fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ for _ in range(int(num_samples / batch_size)):
+ sess.run(fpr_op)
+ # Since this is only approximate, we can't expect a 6-digit match.
+ # With more samples and thresholds, the accuracy should improve.
+ self.assertAlmostEqual(expected_fpr, fpr.eval(), 2)
+
+
+class StreamingFNRThresholdsTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.streaming_false_negative_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0])
+ _assert_local_variables(self, (
+ 'false_negative_rate_at_thresholds/false_negatives:0',
+ 'false_negative_rate_at_thresholds/true_positives:0',))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ fnr, _ = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0],
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [fnr])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ thresholds=[0, 0.5, 1.0],
+ updates_collections=[my_collection_name])
+ self.assertListEqual(
+ ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ thresholds = [0, 0.5, 1.0]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(fnr_op)
+
+ # Then verify idempotency.
+ initial_fnr = fnr.eval()
+ for _ in range(10):
+ self.assertAllClose(initial_fnr, fnr.eval())
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(inputs)
+ thresholds = [0.5]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertEqual(0, fnr.eval())
+
+ def testSomeCorrect(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ thresholds = [0.5]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(0.5, fnr.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
+ thresholds = [0.5]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(1, fnr.eval())
+
+ def testWeights1d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
+ thresholds = [0.5, 1.1]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds, weights=weights)
+
+ fnr_low = fnr[0]
+ fnr_high = fnr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(0.0, fnr_low.eval(), places=5)
+ self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
+
+ def testWeights2d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
+ thresholds = [0.5, 1.1]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds, weights=weights)
+
+ fnr_low = fnr[0]
+ fnr_high = fnr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(0.0, fnr_low.eval(), places=5)
+ self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
+
+ def testExtremeThresholds(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
+ thresholds = [-1.0, 2.0]  # below/above every prediction value
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ fnr_low = fnr[0]
+ fnr_high = fnr[1]
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(0.0, fnr_low.eval())
+ self.assertAlmostEqual(1.0, fnr_high.eval())
+
+ def testZeroLabelsPredictions(self):
+ with self.test_session() as sess:
+ predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
+ labels = array_ops.zeros([4])
+ thresholds = [0.5]
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ predictions, labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(fnr_op)
+
+ self.assertAlmostEqual(0, fnr.eval(), 6)
+
+ def testWithMultipleUpdates(self):
+ num_samples = 1000
+ batch_size = 10
+ num_batches = int(num_samples / batch_size)
+
+ # Create the labels and data.
+ labels = np.random.randint(0, 2, size=(num_samples, 1))
+ noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+ predictions = 0.4 + 0.2 * labels + noise
+ predictions[predictions > 1] = 1
+ predictions[predictions < 0] = 0
+ thresholds = [0.3]
+
+ fn = 0
+ tp = 0
+ for i in range(num_samples):
+ if predictions[i] > thresholds[0]:
+ if labels[i] == 1:
+ tp += 1
+ else:
+ if labels[i] == 1:
+ fn += 1
+ epsilon = 1e-7
+ expected_fnr = fn / (epsilon + fn + tp)
+
+ labels = labels.astype(np.float32)
+ predictions = predictions.astype(np.float32)
+
+ with self.test_session() as sess:
+ # Reshape the data so it's easy to queue up:
+ predictions_batches = predictions.reshape((batch_size, num_batches))
+ labels_batches = labels.reshape((batch_size, num_batches))
+
+ # Enqueue the data:
+ predictions_queue = data_flow_ops.FIFOQueue(
+ num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
+ labels_queue = data_flow_ops.FIFOQueue(
+ num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
+
+ for i in range(int(num_batches)):
+ tf_prediction = constant_op.constant(predictions_batches[:, i])
+ tf_label = constant_op.constant(labels_batches[:, i])
+ sess.run([
+ predictions_queue.enqueue(tf_prediction),
+ labels_queue.enqueue(tf_label)
+ ])
+
+ tf_predictions = predictions_queue.dequeue()
+ tf_labels = labels_queue.dequeue()
+
+ fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
+
+ sess.run(variables.local_variables_initializer())
+ for _ in range(int(num_samples / batch_size)):
+ sess.run(fnr_op)
+ # Since this is only approximate, we can't expect a 6-digit match.
+ # With more samples and thresholds, the accuracy should improve.
+ self.assertAlmostEqual(expected_fnr, fnr.eval(), 2)
+
+
# TODO(ptucker): Remove when we remove `streaming_recall_at_k`.
# This op will be deprecated soon in favor of `streaming_sparse_recall_at_k`.
# Until then, this test validates that both ops yield the same results.