diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-01 16:02:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-01 16:06:05 -0700 |
commit | 70698a168669e0335872ce9248a6c496328d7871 (patch) | |
tree | 943e4f9486f67cb0165200848d73a7fcf0205250 /tensorflow/contrib/metrics | |
parent | 36a4b6c815559a583da093a5c19bce5494f6f66d (diff) |
Adding recall_at_precision to TensorFlow contrib metrics.
PiperOrigin-RevId: 174250716
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r-- | tensorflow/contrib/metrics/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 113 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 117 |
3 files changed, 229 insertions, 3 deletions
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index bb566f6902..302042c4dd 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -66,6 +66,7 @@ See the @{$python/contrib.metrics} guide. @@set_size @@set_union @@count +@@recall_at_precision """ from __future__ import absolute_import @@ -80,6 +81,7 @@ from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histog from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics from tensorflow.contrib.metrics.python.ops.metric_ops import count +from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index fbb030348c..ca4dcef8de 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -38,6 +38,9 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.util.deprecation import deprecated +# Epsilon constant used to represent extremely small quantity. +_EPSILON = 1e-7 + def _safe_div(numerator, denominator, name): """Divides two values, returning 0 if the denominator is <= 0. 
@@ -1061,7 +1064,7 @@ def streaming_curve_points(labels=None, (labels, predictions, weights)): if curve != 'ROC' and curve != 'PR': raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) - kepsilon = 1e-7 # to account for floating point imprecisions + kepsilon = _EPSILON # to account for floating point imprecisions thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] @@ -1654,7 +1657,7 @@ def streaming_false_positive_rate_at_thresholds(predictions, predictions, labels, thresholds, weights, includes=('fp', 'tn')) # Avoid division by zero. - epsilon = 1e-7 + epsilon = _EPSILON def compute_fpr(fp, tn, name): return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) @@ -1725,7 +1728,7 @@ def streaming_false_negative_rate_at_thresholds(predictions, predictions, labels, thresholds, weights, includes=('fn', 'tp')) # Avoid division by zero. - epsilon = 1e-7 + epsilon = _EPSILON def compute_fnr(fn, tp, name): return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) @@ -2153,6 +2156,109 @@ def sparse_recall_at_top_k(labels, name=name_scope) +def _compute_recall_at_precision(tp, fp, fn, precision, name): + """Helper function to compute recall at a given `precision`. + + Args: + tp: The number of true positives. + fp: The number of false positives. + fn: The number of false negatives. + precision: The precision for which the recall will be calculated. + name: An optional variable_scope name. + + Returns: + The recall at the given `precision`. 
+ """ + precisions = math_ops.div(tp, tp + fp + _EPSILON) + tf_index = math_ops.argmin( + math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + + # Now, we have the implicit threshold, so compute the recall: + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) + + +def recall_at_precision(labels, + predictions, + precision, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes `recall` at `precision`. + + The `recall_at_precision` function creates four local variables, + `tp` (true positives), `fp` (false positives), `fn` (false negatives) and + `tn` (true negatives); `tp`, `fp` and `fn` are used to compute the `recall` + at the given `precision` value. The + threshold for the given `precision` value is computed and used to evaluate the + corresponding `recall`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the + weight of each case found in the `predictions` and `labels`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + precision: A scalar value in range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use for matching the given + `precision`. + metrics_collections: An optional list of collections that `recall` + should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. 
+ name: An optional variable_scope name. + + Returns: + recall: A scalar `Tensor` representing the recall at the given + `precision` value. + update_op: An operation that increments the `tp`, `fp` and `fn` + variables appropriately and whose value matches `recall`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, if + `weights` is not `None` and its shape doesn't match `predictions`, or if + `precision` is not between 0 and 1, or if either `metrics_collections` + or `updates_collections` are not a list or tuple. + + """ + if not 0 <= precision <= 1: + raise ValueError('`precision` must be in the range [0, 1].') + + with variable_scope.variable_scope(name, 'recall_at_precision', + (predictions, labels, weights)): + thresholds = [ + i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1) + ] + thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON] + + values, update_ops = _streaming_confusion_matrix_at_thresholds( + labels, predictions, thresholds, weights) + + recall = _compute_recall_at_precision(values['tp'], values['fp'], + values['fn'], precision, 'value') + update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], + update_ops['fn'], precision, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, recall) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return recall, update_op + + def streaming_sparse_average_precision_at_k(predictions, labels, k, @@ -3168,6 +3274,7 @@ __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', 'count', + 'recall_at_precision', 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index ad4741b350..6a8e58b4da 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2917,6 
+2917,123 @@ class StreamingFPRThresholdsTest(test.TestCase): self.assertAlmostEqual(expected_fpr, fpr.eval(), 2) +class RecallAtPrecisionTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7) + _assert_metric_variables(self, ('recall_at_precision/true_positives:0', + 'recall_at_precision/false_negatives:0', + 'recall_at_precision/false_positives:0', + 'recall_at_precision/true_negatives:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7, + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7, + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.7) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. 
+ initial_recall = recall.eval() + for _ in range(10): + self.assertAlmostEqual(initial_recall, recall.eval(), 5) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=1.0) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(1, sess.run(update_op)) + self.assertEqual(1, recall.eval()) + + def testSomeCorrectHighPrecision(self): + predictions_values = [1, .9, .8, .7, .6, .5, .4, .3] + labels_values = [1, 1, 1, 1, 0, 0, 0, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.8) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.8, sess.run(update_op)) + self.assertAlmostEqual(0.8, recall.eval()) + + def testSomeCorrectLowPrecision(self): + predictions_values = [1, .9, .8, .7, .6, .5, .4, .3, .2, .1] + labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + target_recall = 2.0 / 3.0 + self.assertAlmostEqual(target_recall, sess.run(update_op)) + self.assertAlmostEqual(target_recall, recall.eval()) + + def testWeighted(self): + predictions_values = [1, .9, .8, .7, .6] + labels_values = [1, 1, 0, 0, 1] + weights_values = [1, 1, 3, 4, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = 
constant_op.constant(labels_values) + weights = constant_op.constant(weights_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, weights=weights, precision=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + target_recall = 2.0 / 3.0 + self.assertAlmostEqual(target_recall, sess.run(update_op)) + self.assertAlmostEqual(target_recall, recall.eval()) + + class StreamingFNRThresholdsTest(test.TestCase): def setUp(self): |