aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-01 16:02:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-01 16:06:05 -0700
commit70698a168669e0335872ce9248a6c496328d7871 (patch)
tree943e4f9486f67cb0165200848d73a7fcf0205250 /tensorflow/contrib/metrics
parent36a4b6c815559a583da093a5c19bce5494f6f66d (diff)
Adding streaming_recall_at_precision to Tensorflow contrib metrics.
PiperOrigin-RevId: 174250716
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/__init__.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py113
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py117
3 files changed, 229 insertions, 3 deletions
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index bb566f6902..302042c4dd 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -66,6 +66,7 @@ See the @{$python/contrib.metrics} guide.
@@set_size
@@set_union
@@count
+@@recall_at_precision
"""
from __future__ import absolute_import
@@ -80,6 +81,7 @@ from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histog
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
from tensorflow.contrib.metrics.python.ops.metric_ops import count
+from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision
from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index fbb030348c..ca4dcef8de 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -38,6 +38,9 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
+# Small constant used to avoid division by zero and to pad threshold bounds.
+_EPSILON = 1e-7
+
def _safe_div(numerator, denominator, name):
"""Divides two values, returning 0 if the denominator is <= 0.
@@ -1061,7 +1064,7 @@ def streaming_curve_points(labels=None,
(labels, predictions, weights)):
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
- kepsilon = 1e-7 # to account for floating point imprecisions
+ kepsilon = _EPSILON # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
@@ -1654,7 +1657,7 @@ def streaming_false_positive_rate_at_thresholds(predictions,
predictions, labels, thresholds, weights, includes=('fp', 'tn'))
# Avoid division by zero.
- epsilon = 1e-7
+ epsilon = _EPSILON
def compute_fpr(fp, tn, name):
return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name)
@@ -1725,7 +1728,7 @@ def streaming_false_negative_rate_at_thresholds(predictions,
predictions, labels, thresholds, weights, includes=('fn', 'tp'))
# Avoid division by zero.
- epsilon = 1e-7
+ epsilon = _EPSILON
def compute_fnr(fn, tp, name):
return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name)
@@ -2153,6 +2156,109 @@ def sparse_recall_at_top_k(labels,
name=name_scope)
+def _compute_recall_at_precision(tp, fp, fn, precision, name):
+  """Returns the recall at the threshold best matching `precision`.
+
+  The precision is computed for every entry of `tp`/`fp`, the entry whose
+  precision is closest to the requested `precision` is selected, and the
+  recall of that same entry is returned.
+
+  Args:
+    tp: The true positive counts, one entry per candidate threshold.
+    fp: The false positive counts, one entry per candidate threshold.
+    fn: The false negative counts, one entry per candidate threshold.
+    precision: The precision for which the recall will be calculated.
+    name: The name passed to the op that computes the returned recall.
+
+  Returns:
+    The recall at the given `precision`.
+  """
+  # _EPSILON keeps the divisions well defined when a sum of counts is zero.
+  precision_per_threshold = math_ops.div(tp, tp + fp + _EPSILON)
+  distance_to_target = math_ops.abs(precision_per_threshold - precision)
+  # Index of the (implicit) threshold whose precision is nearest the target.
+  best_index = math_ops.argmin(
+      distance_to_target, 0, output_type=dtypes.int32)
+
+  selected_tp = tp[best_index]
+  selected_fn = fn[best_index]
+  return math_ops.div(selected_tp, selected_tp + selected_fn + _EPSILON, name)
+
+
+
+def recall_at_precision(labels,
+                        predictions,
+                        precision,
+                        weights=None,
+                        num_thresholds=200,
+                        metrics_collections=None,
+                        updates_collections=None,
+                        name=None):
+  """Computes `recall` at `precision`.
+
+  The `recall_at_precision` function creates four local variables,
+  `tp` (true positives), `fp` (false positives), `fn` (false negatives) and
+  `tn` (true negatives). The first three are used to compute the `recall` at
+  the given `precision` value. The threshold for the given `precision` value is
+  computed and used to evaluate the corresponding `recall`.
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the
+  `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the
+  weight of each case found in the `predictions` and `labels`.
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+  Args:
+    labels: The ground truth values, a `Tensor` whose dimensions must match
+      `predictions`. Will be cast to `bool`.
+    predictions: A floating point `Tensor` of arbitrary shape and whose values
+      are in the range `[0, 1]`.
+    precision: A scalar value in range `[0, 1]`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+      be either `1`, or the same as the corresponding `labels` dimension).
+    num_thresholds: The number of thresholds to use for matching the given
+      `precision`.
+    metrics_collections: An optional list of collections that `recall`
+      should be added to.
+    updates_collections: An optional list of collections that `update_op` should
+      be added to.
+    name: An optional variable_scope name.
+
+  Returns:
+    recall: A scalar `Tensor` representing the recall at the given
+      `precision` value.
+    update_op: An operation that increments the `tp`, `fp` and `fn`
+      variables appropriately and whose value matches `recall`.
+
+  Raises:
+    ValueError: If `predictions` and `labels` have mismatched shapes, if
+      `weights` is not `None` and its shape doesn't match `predictions`, or if
+      `precision` is not between 0 and 1, or if either `metrics_collections`
+      or `updates_collections` are not a list or tuple.
+
+  """
+  if not 0 <= precision <= 1:
+    raise ValueError('`precision` must be in the range [0, 1].')
+
+  with variable_scope.variable_scope(name, 'recall_at_precision',
+                                     (predictions, labels, weights)):
+    # Evenly spaced interior thresholds, bracketed by values just outside
+    # [0, 1] so that predictions of exactly 0 or 1 fall strictly inside.
+    thresholds = [
+        i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1)
+    ]
+    thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON]
+
+    # The helper takes `predictions` before `labels`, matching the argument
+    # order used by the other rate-at-thresholds metrics in this module.
+    values, update_ops = _streaming_confusion_matrix_at_thresholds(
+        predictions, labels, thresholds, weights)
+
+    recall = _compute_recall_at_precision(values['tp'], values['fp'],
+                                          values['fn'], precision, 'value')
+    update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
+                                             update_ops['fn'], precision,
+                                             'update_op')
+
+    if metrics_collections:
+      ops.add_to_collections(metrics_collections, recall)
+
+    if updates_collections:
+      ops.add_to_collections(updates_collections, update_op)
+
+    return recall, update_op
+
+
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
@@ -3168,6 +3274,7 @@ __all__ = [
'aggregate_metric_map',
'aggregate_metrics',
'count',
+ 'recall_at_precision',
'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index ad4741b350..6a8e58b4da 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2917,6 +2917,123 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(expected_fpr, fpr.eval(), 2)
+class RecallAtPrecisionTest(test.TestCase):
+  """Tests for `metrics.recall_at_precision`.
+
+  `labels` and `predictions` are always passed by keyword: the metric's
+  signature is `(labels, predictions, ...)`, so passing them positionally in
+  the order `predictions, labels` would silently swap their roles.
+  """
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testVars(self):
+    metrics.recall_at_precision(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        precision=0.7)
+    _assert_metric_variables(self, ('recall_at_precision/true_positives:0',
+                                    'recall_at_precision/false_negatives:0',
+                                    'recall_at_precision/false_positives:0',
+                                    'recall_at_precision/true_negatives:0'))
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    recall, _ = metrics.recall_at_precision(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        precision=0.7,
+        metrics_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [recall])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, update_op = metrics.recall_at_precision(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        precision=0.7,
+        updates_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
+    recall, update_op = metrics.recall_at_precision(
+        labels=labels, predictions=predictions, precision=0.7)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      # Run several updates.
+      for _ in range(10):
+        sess.run(update_op)
+
+      # Then verify idempotency.
+      initial_recall = recall.eval()
+      for _ in range(10):
+        self.assertAlmostEqual(initial_recall, recall.eval(), 5)
+
+  def testAllCorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+
+    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+    labels = constant_op.constant(inputs)
+    recall, update_op = metrics.recall_at_precision(
+        labels=labels, predictions=predictions, precision=1.0)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      self.assertEqual(1, sess.run(update_op))
+      self.assertEqual(1, recall.eval())
+
+  def testSomeCorrectHighPrecision(self):
+    predictions_values = [1, .9, .8, .7, .6, .5, .4, .3]
+    labels_values = [1, 1, 1, 1, 0, 0, 0, 1]
+
+    predictions = constant_op.constant(
+        predictions_values, dtype=dtypes_lib.float32)
+    labels = constant_op.constant(labels_values)
+    recall, update_op = metrics.recall_at_precision(
+        labels=labels, predictions=predictions, precision=0.8)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      self.assertAlmostEqual(0.8, sess.run(update_op))
+      self.assertAlmostEqual(0.8, recall.eval())
+
+  def testSomeCorrectLowPrecision(self):
+    predictions_values = [1, .9, .8, .7, .6, .5, .4, .3, .2, .1]
+    labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+
+    predictions = constant_op.constant(
+        predictions_values, dtype=dtypes_lib.float32)
+    labels = constant_op.constant(labels_values)
+    recall, update_op = metrics.recall_at_precision(
+        labels=labels, predictions=predictions, precision=0.4)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      target_recall = 2.0 / 3.0
+      self.assertAlmostEqual(target_recall, sess.run(update_op))
+      self.assertAlmostEqual(target_recall, recall.eval())
+
+  def testWeighted(self):
+    predictions_values = [1, .9, .8, .7, .6]
+    labels_values = [1, 1, 0, 0, 1]
+    weights_values = [1, 1, 3, 4, 1]
+
+    predictions = constant_op.constant(
+        predictions_values, dtype=dtypes_lib.float32)
+    labels = constant_op.constant(labels_values)
+    weights = constant_op.constant(weights_values)
+    recall, update_op = metrics.recall_at_precision(
+        labels=labels, predictions=predictions, weights=weights, precision=0.4)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      target_recall = 2.0 / 3.0
+      self.assertAlmostEqual(target_recall, sess.run(update_op))
+      self.assertAlmostEqual(target_recall, recall.eval())
+
+
+
class StreamingFNRThresholdsTest(test.TestCase):
def setUp(self):