Diffstat (limited to 'tensorflow/contrib/metrics/python/metrics/classification_test.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/metrics/classification_test.py | 202
1 file changed, 202 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index fa0f12d029..3d0b81c1be 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -18,9 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.contrib.metrics.python.metrics import classification
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase):
     self.assertEqual(result, 0.5)
 
 
+class F1ScoreTest(test.TestCase):
+
+  def setUp(self):
+    super(F1ScoreTest, self).setUp()
+    np.random.seed(1)
+
+  def testVars(self):
+    classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3)
+    expected = {'f1/true_positives:0', 'f1/false_positives:0',
+                'f1/false_negatives:0'}
+    self.assertEquals(
+        expected, set(v.name for v in variables.local_variables()))
+    self.assertEquals(
+        set(expected), set(v.name for v in variables.local_variables()))
+    self.assertEquals(
+        set(expected),
+        set(v.name
+            for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    f1, _ = classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3,
+        metrics_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [f1])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, f1_op = classification.f1_score(
+        predictions=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_thresholds=3,
+        updates_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [f1_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 3), maxval=1, dtype=dtypes.float32, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 3), maxval=2, dtype=dtypes.int64, seed=2)
+    f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      # Run several updates.
+      for _ in range(10):
+        sess.run([f1_op])
+
+      # Then verify idempotency.
+      initial_f1 = f1.eval()
+      for _ in range(10):
+        self.assertAllClose(initial_f1, f1.eval())
+
+  def testAllCorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+      labels = constant_op.constant(inputs)
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertEqual(1, f1.eval())
+
+  def testSomeCorrect(self):
+    predictions = constant_op.constant(
+        [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
+    labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+    f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+      # Threshold 0 will have around 0.5 precision and 1 recall, yielding an
+      # F1 score of 2 * 0.5 * 1 / (1 + 0.5).
+      self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval())
+
+  def testAllIncorrect(self):
+    inputs = np.random.randint(0, 2, size=(10000, 1))
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+      labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      # Threshold 0 will have around 0.5 precision and 1 recall, yielding an
+      # F1 score of 2 * 0.5 * 1 / (1 + 0.5).
+      self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
+
+  def testWeights1d(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+      weights = constant_op.constant(
+          [[0], [1]], shape=(2, 1), dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels, weights,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+  def testWeights2d(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+      labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+      weights = constant_op.constant(
+          [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32)
+      f1, f1_op = classification.f1_score(predictions, labels, weights,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+  def testZeroLabelsPredictions(self):
+    with self.test_session() as sess:
+      predictions = array_ops.zeros([4], dtype=dtypes.float32)
+      labels = array_ops.zeros([4])
+      f1, f1_op = classification.f1_score(predictions, labels,
+                                          num_thresholds=3)
+      sess.run(variables.local_variables_initializer())
+      sess.run([f1_op])
+
+      self.assertAlmostEqual(0.0, f1.eval(), places=5)
+
+  def testWithMultipleUpdates(self):
+    num_samples = 1000
+    batch_size = 10
+    num_batches = int(num_samples / batch_size)
+
+    # Create the labels and data.
+    labels = np.random.randint(0, 2, size=(num_samples, 1))
+    noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+    predictions = 0.4 + 0.2 * labels + noise
+    predictions[predictions > 1] = 1
+    predictions[predictions < 0] = 0
+    thresholds = [-0.01, 0.5, 1.01]
+
+    expected_max_f1 = -1.0
+    for threshold in thresholds:
+      tp = 0
+      fp = 0
+      fn = 0
+      tn = 0
+      for i in range(num_samples):
+        if predictions[i] >= threshold:
+          if labels[i] == 1:
+            tp += 1
+          else:
+            fp += 1
+        else:
+          if labels[i] == 1:
+            fn += 1
+          else:
+            tn += 1
+      epsilon = 1e-7
+      expected_prec = tp / (epsilon + tp + fp)
+      expected_rec = tp / (epsilon + tp + fn)
+      expected_f1 = (2 * expected_prec * expected_rec /
+                     (epsilon + expected_prec + expected_rec))
+      if expected_f1 > expected_max_f1:
+        expected_max_f1 = expected_f1
+
+    labels = labels.astype(np.float32)
+    predictions = predictions.astype(np.float32)
+    tf_predictions, tf_labels = (dataset_ops.Dataset
+                                 .from_tensor_slices((predictions, labels))
+                                 .repeat()
+                                 .batch(batch_size)
+                                 .make_one_shot_iterator()
+                                 .get_next())
+    f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
+                                        num_thresholds=3)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      for _ in range(num_batches):
+        sess.run([f1_op])
+      # Since this is only approximate, we can't expect a six-digit match,
+      # although with more samples and thresholds the accuracy should improve.
+      self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
+
+
 if __name__ == '__main__':
   test.main()
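As a note on the expected value computed in testWithMultipleUpdates above: the brute-force loop finds the best F1 over a fixed set of thresholds, where at each threshold F1 is the harmonic mean of precision and recall with a 1e-7 epsilon guard against division by zero. For instance, with precision 0.5 and recall 1 at threshold 0, F1 = 2 * 0.5 * 1 / (0.5 + 1) = 2/3, the constant asserted in testSomeCorrect and testAllIncorrect. Below is a minimal NumPy sketch of the same computation; the helper name max_f1_over_thresholds is ours for illustration and is not part of this change or of the TensorFlow API.

    import numpy as np


    def max_f1_over_thresholds(labels, predictions, thresholds, epsilon=1e-7):
      """Returns the best F1 score over the given decision thresholds."""
      best = -1.0
      for threshold in thresholds:
        positives = predictions >= threshold      # predicted positive
        tp = np.sum(positives & (labels == 1))    # true positives
        fp = np.sum(positives & (labels == 0))    # false positives
        fn = np.sum(~positives & (labels == 1))   # false negatives
        precision = tp / (epsilon + tp + fp)
        recall = tp / (epsilon + tp + fn)
        f1 = 2 * precision * recall / (epsilon + precision + recall)
        best = max(best, f1)
      return best


    # Same synthetic data scheme as the test: labels shift the prediction mean.
    np.random.seed(1)
    labels = np.random.randint(0, 2, size=1000)
    predictions = np.clip(
        0.4 + 0.2 * labels + np.random.normal(0.0, 0.2, size=1000), 0, 1)
    print(max_f1_over_thresholds(labels, predictions, [-0.01, 0.5, 1.01]))

The streaming metric under test accumulates true positives, false positives, and false negatives across sess.run([f1_op]) calls and only approximates this exact sweep, which is why the test compares to just two decimal places.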