author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-07-11 11:36:48 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-07-11 11:40:24 -0700
commit     8111f75a83d5d31b4f23fea45488c5a2b43b5bd8 (patch)
tree       2db20becf47c14fc931af9d44f922a4bb9c86b14 /tensorflow/contrib/metrics
parent     68556f86ad5d2c498828f309fca429f1ad9640f8 (diff)
Implements tf.metrics.f1_score as the maximum f1 score across different thresholds.
PiperOrigin-RevId: 204159841
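Before the diff itself, it may help to see the quantity this change computes in isolation. The following is a rough plain-NumPy sketch, not the TensorFlow implementation; the helper name `best_f1_score` is ours. It mirrors the linearly spaced threshold grid (endpoints nudged just outside [0, 1]) and the epsilon-guarded precision/recall described in classification.py below, and uses the same `>=` thresholding as the test's hand-computed reference.

    import numpy as np

    def best_f1_score(labels, predictions, num_thresholds=200, epsilon=1e-7):
      # Linearly spaced thresholds in (0, 1), plus endpoints just outside the
      # interval so "predict everything" and "predict nothing" are covered.
      thresholds = [(i + 1) / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]
      thresholds = [-epsilon] + thresholds + [1.0 + epsilon]
      best = 0.0
      for t in thresholds:
        predicted = predictions >= t
        tp = np.sum(predicted & (labels == 1))
        fp = np.sum(predicted & (labels == 0))
        fn = np.sum(~predicted & (labels == 1))
        precision = tp / (epsilon + tp + fp)
        recall = tp / (epsilon + tp + fn)
        best = max(best, 2 * precision * recall / (precision + recall + epsilon))
      return best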
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--  tensorflow/contrib/metrics/BUILD                                  |   1
-rw-r--r--  tensorflow/contrib/metrics/__init__.py                            |   1
-rw-r--r--  tensorflow/contrib/metrics/python/metrics/classification.py      | 121
-rw-r--r--  tensorflow/contrib/metrics/python/metrics/classification_test.py | 202
4 files changed, 325 insertions, 0 deletions
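With this patch applied, the metric is exported as tf.contrib.metrics.f1_score. A minimal usage sketch in TF 1.x graph mode (the example tensors are made up; the pattern of initializing local variables and running the update op comes from the tests below):

    import tensorflow as tf

    labels = tf.constant([[0., 1., 1., 0.]])
    predictions = tf.constant([[0.1, 0.8, 0.6, 0.4]])
    f1, f1_op = tf.contrib.metrics.f1_score(labels, predictions,
                                            num_thresholds=3)

    with tf.Session() as sess:
      # The metric accumulates tp/fp/fn in local variables.
      sess.run(tf.local_variables_initializer())
      sess.run(f1_op)      # update the counts with this batch
      print(sess.run(f1))  # best F1 across the threshold grid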
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 66cb493e5c..21cd34f73f 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -31,6 +31,7 @@ py_library(
         "//tensorflow/python:check_ops",
         "//tensorflow/python:confusion_matrix",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:distribute",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:histogram_ops",
         "//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 5effea3596..88798d61b7 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -63,6 +63,7 @@ See the @{$python/contrib.metrics} guide.
 @@aggregate_metrics
 @@aggregate_metric_map
 @@confusion_matrix
+@@f1_score
 @@set_difference
 @@set_intersection
 @@set_size
diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py
index 26aba1cc51..e553612269 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification.py
@@ -22,6 +22,9 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribute as distribute_lib
 
 
 # TODO(nsilberman): move into metrics/python/ops/
@@ -62,3 +65,121 @@ def accuracy(predictions, labels, weights=None, name=None):
       return math_ops.div(math_ops.reduce_sum(is_correct),
                           math_ops.reduce_sum(num_values))
     return math_ops.reduce_mean(is_correct)
+
+
+def f1_score(labels, predictions, weights=None, num_thresholds=200,
+             metrics_collections=None, updates_collections=None, name=None):
+  """Computes the approximately best F1-score across different thresholds.
+
+  The f1_score function applies a range of thresholds to the predictions to
+  convert them from [0, 1] to bool. Precision and recall are computed by
+  comparing them to the labels. The F1-score is then defined as
+  2 * precision * recall / (precision + recall). The best one across the
+  thresholds is returned.
+
+  Disclaimer: In practice it may be desirable to choose the best threshold on
+  the validation set and evaluate the F1 score with this threshold on a
+  separate test set. Or it may be desirable to use a fixed threshold (e.g.
+  0.5).
+
+  This function internally creates three local variables, `true_positives`,
+  `false_positives` and `false_negatives`, which are used to compute pairs of
+  precision and recall values for a linearly spaced set of thresholds from
+  which the best F1-score is derived.
+
+  This value is ultimately returned as `f1_score`, an idempotent operation
+  that computes the F1-score from the aforementioned variables. The
+  `num_thresholds` argument controls the degree of discretization; larger
+  numbers of thresholds more closely approximate the true best F1-score.
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the F1-score.
+
+  Example usage with a custom estimator:
+
+    def model_fn(features, labels, mode):
+      predictions = make_predictions(features)
+      loss = make_loss(predictions, labels)
+      train_op = tf.contrib.training.create_train_op(
+          total_loss=loss,
+          optimizer='Adam')
+      eval_metric_ops = {'f1': f1_score(labels, predictions)}
+      return tf.estimator.EstimatorSpec(
+          mode=mode,
+          predictions=predictions,
+          loss=loss,
+          train_op=train_op,
+          eval_metric_ops=eval_metric_ops,
+          export_outputs=export_outputs)
+
+    estimator = tf.estimator.Estimator(model_fn=model_fn)
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask
+  values.
+
+  Args:
+    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+      `bool`.
+    predictions: A floating point `Tensor` of arbitrary shape and whose values
+      are in the range `[0, 1]`.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+      must be either `1`, or the same as the corresponding `labels`
+      dimension).
+    num_thresholds: The number of thresholds to use when discretizing the
+      precision/recall trade-off.
+    metrics_collections: An optional list of collections that `f1_score`
+      should be added to.
+    updates_collections: An optional list of collections that `update_op`
+      should be added to.
+    name: An optional variable_scope name.
+
+  Returns:
+    f1_score: A scalar `Tensor` representing the current best F1-score across
+      different thresholds.
+    update_op: An operation that increments the `true_positives`,
+      `false_positives` and `false_negatives` variables appropriately and
+      whose value matches the `f1_score`.
+
+  Raises:
+    ValueError: If `predictions` and `labels` have mismatched shapes, or if
+      `weights` is not `None` and its shape doesn't match `predictions`, or
+      if either `metrics_collections` or `updates_collections` are not a
+      list or tuple.
+  """
+  with variable_scope.variable_scope(
+      name, 'f1', (labels, predictions, weights)):
+    predictions, labels, weights = metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
+        predictions=predictions, labels=labels, weights=weights)
+    # To account for floating point imprecisions / avoid division by zero.
+    epsilon = 1e-7
+    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                  for i in range(num_thresholds - 2)]
+    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]
+
+    # Confusion matrix.
+    values, update_ops = metrics_impl._confusion_matrix_at_thresholds(  # pylint: disable=protected-access
+        labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))
+
+    # Compute precision and recall at various thresholds.
+    def compute_best_f1_score(tp, fp, fn, name):
+      precision_at_t = math_ops.div(tp, epsilon + tp + fp,
+                                    name='precision_' + name)
+      recall_at_t = math_ops.div(tp, epsilon + tp + fn,
+                                 name='recall_' + name)
+      # Compute F1 score.
+      f1_at_thresholds = (
+          2.0 * precision_at_t * recall_at_t /
+          (precision_at_t + recall_at_t + epsilon))
+      return math_ops.reduce_max(f1_at_thresholds)
+
+    def f1_across_towers(_, values):
+      best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
+                                      fn=values['fn'], name='value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, best_f1)
+      return best_f1
+
+    best_f1 = distribute_lib.get_tower_context().merge_call(
+        f1_across_towers, values)
+
+    update_op = compute_best_f1_score(tp=update_ops['tp'],
+                                      fp=update_ops['fp'],
+                                      fn=update_ops['fn'], name='update')
+    if updates_collections:
+      ops.add_to_collections(updates_collections, update_op)
+
+    return best_f1, update_op
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index fa0f12d029..3d0b81c1be 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -18,9 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.contrib.metrics.python.metrics import classification
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase):
     self.assertEqual(result, 0.5)
 
 
+class F1ScoreTest(test.TestCase):
+
+  def setUp(self):
+    super(F1ScoreTest, self).setUp()
+    np.random.seed(1)
+
+  def testVars(self):
+    classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3)
+    expected = {'f1/true_positives:0', 'f1/false_positives:0',
+                'f1/false_negatives:0'}
+    self.assertEquals(
+        expected, set(v.name for v in variables.local_variables()))
+    self.assertEquals(
+        expected,
+        set(v.name
+            for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    f1, _ = classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3,
+        metrics_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [f1])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, f1_op = classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3,
+        updates_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [f1_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 3), maxval=1, dtype=dtypes.float32, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 3), maxval=2, dtype=dtypes.int64, seed=2)
+    f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      # Run several updates.
+      for _ in range(10):
+        sess.run([f1_op])
+
+      # Then verify idempotency.
+      initial_f1 = f1.eval()
+      for _ in range(10):
+        self.assertAllClose(initial_f1, f1.eval())
+
+  def testAllCorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+      labels = constant_op.constant(inputs)
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertEqual(1, f1.eval())
+
+  def testSomeCorrect(self):
+    predictions = constant_op.constant(
+        [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
+    labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+    f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+      # Threshold 0 will have around 0.5 precision and 1 recall yielding an
+      # F1 score of 2 * 0.5 * 1 / (1 + 0.5).
+      self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval())
+
+  def testAllIncorrect(self):
+    inputs = np.random.randint(0, 2, size=(10000, 1))
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+      labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      # Threshold 0 will have around 0.5 precision and 1 recall yielding an
+      # F1 score of 2 * 0.5 * 1 / (1 + 0.5).
+      self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
+
+  def testWeights1d(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+      weights = constant_op.constant(
+          [[0], [1]], shape=(2, 1), dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels, weights,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+  def testWeights2d(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+      weights = constant_op.constant(
+          [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels, weights,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+  def testZeroLabelsPredictions(self):
+    with self.test_session() as sess:
+      predictions = array_ops.zeros([4], dtype=dtypes.float32)
+      labels = array_ops.zeros([4])
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(0.0, f1.eval(), places=5)
+
+  def testWithMultipleUpdates(self):
+    num_samples = 1000
+    batch_size = 10
+    num_batches = int(num_samples / batch_size)
+
+    # Create the labels and data.
+    labels = np.random.randint(0, 2, size=(num_samples, 1))
+    noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+    predictions = 0.4 + 0.2 * labels + noise
+    predictions[predictions > 1] = 1
+    predictions[predictions < 0] = 0
+    thresholds = [-0.01, 0.5, 1.01]
+
+    expected_max_f1 = -1.0
+    for threshold in thresholds:
+      tp = 0
+      fp = 0
+      fn = 0
+      tn = 0
+      for i in range(num_samples):
+        if predictions[i] >= threshold:
+          if labels[i] == 1:
+            tp += 1
+          else:
+            fp += 1
+        else:
+          if labels[i] == 1:
+            fn += 1
+          else:
+            tn += 1
+      epsilon = 1e-7
+      expected_prec = tp / (epsilon + tp + fp)
+      expected_rec = tp / (epsilon + tp + fn)
+      expected_f1 = (2 * expected_prec * expected_rec /
+                     (epsilon + expected_prec + expected_rec))
+      if expected_f1 > expected_max_f1:
+        expected_max_f1 = expected_f1
+
+    labels = labels.astype(np.float32)
+    predictions = predictions.astype(np.float32)
+    tf_predictions, tf_labels = (dataset_ops.Dataset
+                                 .from_tensor_slices((predictions, labels))
+                                 .repeat()
+                                 .batch(batch_size)
+                                 .make_one_shot_iterator()
+                                 .get_next())
+    f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
+                                        num_thresholds=3)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      for _ in range(num_batches):
+        sess.run([f1_op])
+      # Since this is only approximate, we can't expect a 6-digit match.
+      # With a higher number of samples and thresholds the approximation
+      # should improve.
+      self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
+
+
 if __name__ == '__main__':
   test.main()
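As a closing reference, here is how the three-point threshold grid used by these tests comes out of the construction in classification.py. The endpoints sit just outside [0, 1], which is why testWithMultipleUpdates uses [-0.01, 0.5, 1.01] as its hand-computed stand-in:

    num_thresholds = 3
    epsilon = 1e-7
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                  for i in range(num_thresholds - 2)]
    thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]
    print(thresholds)  # [-1e-07, 0.5, 1.0000001]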