diff options
author | 2017-10-19 16:00:31 -0700 | |
---|---|---|
committer | 2017-10-19 16:11:36 -0700 | |
commit | bc93dcbd9f7b445c5f6f0d1c8f597324d412a76a (patch) | |
tree | 214fec46b32cc834fd6d5dabb6ca20e61c64bc7a | |
parent | f080052284a4a39113051fb1178d91365e9872a8 (diff) |
Fix precision/recall test.
Precision and Recall have as the numerator TP: true positives.
The labels generated in the test were only negative, and hence the test passed before because all updates were 0.
PiperOrigin-RevId: 172812994
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 58 |
1 files changed, 26 insertions, 32 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index cc0ad155fa..f288fceef6 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) with self.test_session() as sess: @@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) with self.test_session() as sess: @@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fpr, update_op = metrics.streaming_false_positive_rate( predictions, labels) @@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) fnr, update_op = metrics.streaming_false_negative_rate( predictions, labels) @@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) with self.test_session() as sess: @@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) @@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, labels, @@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) + # Run several updates. + for _ in range(10): + sess.run([prec_op, rec_op]) + + # Then verify idempotency. initial_prec = prec.eval() initial_rec = rec.eval() for _ in range(10): - sess.run([prec_op, rec_op]) self.assertAllClose(initial_prec, prec.eval()) self.assertAllClose(initial_rec, rec.eval()) @@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] sess.run(variables.local_variables_initializer()) sess.run([prec_op, rec_op]) @@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( predictions, labels, thresholds) @@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds) |