author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-18 14:38:07 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-18 14:40:50 -0700
commit    325ba9ece698d04082b173ba300a10623d27de96 (patch)
tree      c84d5066572768286050ea4c8d16f4ccf14767fb /tensorflow/contrib/metrics
parent    b75e1204d3aaab20d7a937edd6b2f05ff5785827 (diff)
Adds an implementation of the precision at recall metric.
PiperOrigin-RevId: 193418737
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--  tensorflow/contrib/metrics/__init__.py                    |   2
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py       | 115
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops_test.py  | 132
3 files changed, 249 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index de02dc8f45..5effea3596 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -71,6 +71,7 @@ See the @{$python/contrib.metrics} guide.
@@count
@@precision_recall_at_equal_thresholds
@@recall_at_precision
+@@precision_at_recall
"""
from __future__ import absolute_import
@@ -87,6 +88,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
from tensorflow.contrib.metrics.python.ops.metric_ops import auc_with_confidence_intervals
from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa
from tensorflow.contrib.metrics.python.ops.metric_ops import count
+from tensorflow.contrib.metrics.python.ops.metric_ops import precision_at_recall
from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision
from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k
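With the export above in place, the new metric is callable as tf.contrib.metrics.precision_at_recall. The snippet below is a minimal usage sketch and is not part of this change: the label and prediction values are illustrative placeholders, and the session boilerplate follows the usual streaming-metric pattern (the tp/fp/tn/fn counters live in local variables, update_op accumulates them, and the value tensor reads the result).

import tensorflow as tf

labels = tf.constant([0, 0, 1, 1, 1])                  # illustrative data only
predictions = tf.constant([0.1, 0.6, 0.4, 0.7, 0.9])

precision, update_op = tf.contrib.metrics.precision_at_recall(
    labels, predictions, target_recall=0.8)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # initialize the metric counters
  sess.run(update_op)                         # accumulate counts for this batch
  print(sess.run(precision))                  # precision at the smallest recall >= 0.8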
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 9c8ae48094..5364e3075d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2588,6 +2588,121 @@ def recall_at_precision(labels,
return recall, update_op
+def precision_at_recall(labels,
+ predictions,
+ target_recall,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ """Computes the precision at a given recall.
+
+ This function creates variables to track the true positives, false positives,
+ true negatives, and false negatives at a set of thresholds. Among those
+ thresholds where recall is at least `target_recall`, precision is computed
+ at the threshold where recall is closest to `target_recall`.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ precision at `target_recall`. `update_op` increments the counts of true
+ positives, false positives, true negatives, and false negatives with the
+ weight of each case found in the `predictions` and `labels`.
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ For additional information about precision and recall, see
+ http://en.wikipedia.org/wiki/Precision_and_recall
+
+ Args:
+ labels: The ground truth values, a `Tensor` whose dimensions must match
+ `predictions`. Will be cast to `bool`.
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ target_recall: A scalar value in range `[0, 1]`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ num_thresholds: The number of thresholds to use for matching the given
+ recall.
+ metrics_collections: An optional list of collections to which `precision`
+ should be added.
+ updates_collections: An optional list of collections to which `update_op`
+ should be added.
+ name: An optional variable_scope name.
+
+ Returns:
+ precision: A scalar `Tensor` representing the precision at the given
+ `target_recall` value.
+ update_op: An operation that increments the variables for tracking the
+ true positives, false positives, true negatives, and false negatives and
+ whose value matches `precision`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ `target_recall` is not between 0 and 1, or if either `metrics_collections`
+ or `updates_collections` are not a list or tuple.
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError('tf.metrics.precision_at_recall is not '
+ 'supported when eager execution is enabled.')
+
+ if target_recall < 0 or target_recall > 1:
+ raise ValueError('`target_recall` must be in the range [0, 1].')
+
+ with variable_scope.variable_scope(name, 'precision_at_recall',
+ (predictions, labels, weights)):
+ kepsilon = 1e-7 # Used to avoid division by zero.
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
+ thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+ values, update_ops = _streaming_confusion_matrix_at_thresholds(
+ predictions, labels, thresholds, weights)
+
+ def compute_precision_at_recall(tp, fp, fn, name):
+ """Computes the precision at a given recall.
+
+ Args:
+ tp: True positives.
+ fp: False positives.
+ fn: False negatives.
+ name: A name for the operation.
+
+ Returns:
+ The precision at the desired recall.
+ """
+ recalls = math_ops.div(tp, tp + fn + kepsilon)
+
+ # Because recall is monotone decreasing as a function of the threshold,
+ # the smallest recall exceeding target_recall occurs at the largest
+ # threshold where recall >= target_recall.
+ admissible_recalls = math_ops.cast(
+ math_ops.greater_equal(recalls, target_recall), dtypes.int64)
+ tf_index = math_ops.reduce_sum(admissible_recalls) - 1
+
+ # Now we have the threshold at which to compute precision:
+ return math_ops.div(tp[tf_index] + kepsilon,
+ tp[tf_index] + fp[tf_index] + kepsilon,
+ name)
+
+ precision_value = compute_precision_at_recall(
+ values['tp'], values['fp'], values['fn'], 'value')
+ update_op = compute_precision_at_recall(
+ update_ops['tp'], update_ops['fp'], update_ops['fn'], 'update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, precision_value)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return precision_value, update_op
+
+
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
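The threshold search inside compute_precision_at_recall can also be illustrated outside the graph. The following is a rough NumPy sketch, not the library implementation: the name precision_at_recall_np is made up here, it handles only unweighted binary labels, and it assumes (as in the core streaming confusion-matrix helper) that a prediction counts as positive when it is strictly greater than the threshold. It builds the same threshold grid, computes tp/fp/fn at every threshold, and picks the largest threshold whose recall still meets the target.

import numpy as np

def precision_at_recall_np(labels, predictions, target_recall,
                           num_thresholds=200, eps=1e-7):
  # Unweighted NumPy illustration of the threshold search used above.
  labels = np.asarray(labels, dtype=bool)
  predictions = np.asarray(predictions, dtype=np.float64)
  # Same grid as the op: -eps, 1/(n-1), 2/(n-1), ..., (n-2)/(n-1), 1+eps.
  grid = [(i + 1.0) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
  thresholds = np.array([-eps] + grid + [1.0 + eps])
  # Confusion counts at every threshold; "positive" means prediction > threshold.
  pred_pos = predictions[None, :] > thresholds[:, None]
  tp = np.sum(pred_pos & labels[None, :], axis=1).astype(float)
  fp = np.sum(pred_pos & ~labels[None, :], axis=1).astype(float)
  fn = np.sum(~pred_pos & labels[None, :], axis=1).astype(float)
  recalls = tp / (tp + fn + eps)
  # Recall is non-increasing in the threshold, so the thresholds meeting the
  # target form a prefix; the last admissible index is the largest threshold.
  index = np.sum(recalls >= target_recall) - 1
  return (tp[index] + eps) / (tp[index] + fp[index] + eps)

# Reproduces the expectation of testSomeCorrectLowRecall below (2/3).
print(precision_at_recall_np(
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
    [0.1, 0.2, 0.7, 0.3, 0.0, 0.1, 0.45, 0.5, 0.6, 0.9],
    target_recall=0.4))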
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 33eb655fb6..76420db8bd 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -3380,6 +3380,138 @@ class RecallAtPrecisionTest(test.TestCase):
self.assertAlmostEqual(target_recall, recall.eval())
+class PrecisionAtRecallTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.precision_at_recall(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ target_recall=0.7)
+ _assert_metric_variables(self,
+ ('precision_at_recall/true_positives:0',
+ 'precision_at_recall/false_negatives:0',
+ 'precision_at_recall/false_positives:0',
+ 'precision_at_recall/true_negatives:0'))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ mean, _ = metrics.precision_at_recall(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ target_recall=0.7,
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [mean])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.precision_at_recall(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ target_recall=0.7,
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.7)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(update_op)
+
+ # Then verify idempotency.
+ initial_precision = precision.eval()
+ for _ in range(10):
+ self.assertAlmostEqual(initial_precision, precision.eval(), places=5)
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(inputs)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.7)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(1, sess.run(update_op))
+ self.assertEqual(1, precision.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = 1.0 - predictions
+ label_prior = math_ops.reduce_mean(labels)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.2)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertEqual(sess.run(label_prior), sess.run(update_op))
+ self.assertEqual(sess.run(label_prior), precision.eval())
+
+ def testSomeCorrectHighRecall(self):
+ predictions_values = [0.1, 0.2, 0.5, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9]
+ labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.8)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(0.8, sess.run(update_op))
+ self.assertAlmostEqual(0.8, precision.eval())
+
+ def testSomeCorrectLowRecall(self):
+ predictions_values = [0.1, 0.2, 0.7, 0.3, 0.0, 0.1, 0.45, 0.5, 0.6, 0.9]
+ labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.4)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(2.0/3, sess.run(update_op))
+ self.assertAlmostEqual(2.0/3, precision.eval())
+
+ def testWeighted_multipleLabelDtypes(self):
+ for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
+ predictions_values = [
+ 0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.22, 0.25, 0.31, 0.35]
+ labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+ weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = math_ops.cast(labels_values, dtype=label_dtype)
+ weights = constant_op.constant(weights_values)
+ precision, update_op = metrics.precision_at_recall(
+ labels, predictions, target_recall=0.8, weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(34.0/43, sess.run(update_op))
+ self.assertAlmostEqual(34.0/43, precision.eval())
+
+
class StreamingFNRThresholdsTest(test.TestCase):
def setUp(self):
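As a sanity check on the 34.0/43 expectation in testWeighted_multipleLabelDtypes above: the positive examples carry weights 6+7+8+9+10 = 40, so recall >= 0.8 requires a true-positive weight of at least 32. The largest grid threshold that still satisfies this sits just below 0.22; there the positives scored 0.22, 0.25, 0.31 and 0.35 contribute tp = 7+8+9+10 = 34, the negatives scored 0.3 and 0.4 contribute fp = 4+5 = 9, and the resulting precision is 34/(34+9) = 34/43.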