author A. Unique TensorFlower <gardener@tensorflow.org> 2017-10-19 16:00:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-19 16:11:36 -0700
commit bc93dcbd9f7b445c5f6f0d1c8f597324d412a76a (patch)
tree 214fec46b32cc834fd6d5dabb6ca20e61c64bc7a /tensorflow/contrib/metrics
parent f080052284a4a39113051fb1178d91365e9872a8 (diff)
Fix precision/recall test.
Precision and Recall both have TP (true positives) as their numerator. The labels generated in the test were all negative, so the test previously passed only because every update was 0.
PiperOrigin-RevId: 172812994
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r-- tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 58
1 file changed, 26 insertions(+), 32 deletions(-)
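The root cause is how random_uniform samples integer dtypes: values are drawn from the half-open interval [minval, maxval), so maxval=1 with the default minval=0 can only ever produce 0, and the generated labels contained no positives. A minimal sketch of the before/after label generation, assuming TensorFlow 1.x (the API used by this test file):

import tensorflow as tf

# Integer values are drawn from [minval, maxval); with the default minval=0 and
# maxval=1, the only possible value is 0, so no example can be a true positive.
all_zero_labels = tf.random_uniform(
    (10, 3), maxval=1, dtype=tf.int64, seed=2)   # every element is 0
mixed_labels = tf.random_uniform(
    (10, 3), maxval=2, dtype=tf.int64, seed=2)   # elements are 0 or 1

with tf.Session() as sess:
    print(sess.run(all_zero_labels))   # all zeros: precision/recall updates stay 0
    print(sess.run(mixed_labels))      # mix of 0s and 1s: true positives become possible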
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index cc0ad155fa..f288fceef6 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
with self.test_session() as sess:
@@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
with self.test_session() as sess:
@@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels)
@@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels)
@@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
with self.test_session() as sess:
@@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
@@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
labels,
@@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
- # Run several updates, then verify idempotency.
- sess.run([prec_op, rec_op])
+ # Run several updates.
+ for _ in range(10):
+ sess.run([prec_op, rec_op])
+
+ # Then verify idempotency.
initial_prec = prec.eval()
initial_rec = rec.eval()
for _ in range(10):
- sess.run([prec_op, rec_op])
self.assertAllClose(initial_prec, prec.eval())
self.assertAllClose(initial_rec, rec.eval())
@@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)
- [prec_low, prec_high] = array_ops.split(
- value=prec, num_or_size_splits=2, axis=0)
- prec_low = array_ops.reshape(prec_low, shape=())
- prec_high = array_ops.reshape(prec_high, shape=())
- [rec_low, rec_high] = array_ops.split(
- value=rec, num_or_size_splits=2, axis=0)
- rec_low = array_ops.reshape(rec_low, shape=())
- rec_high = array_ops.reshape(rec_high, shape=())
+ prec_low = prec[0]
+ prec_high = prec[1]
+ rec_low = rec[0]
+ rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)
- [prec_low, prec_high] = array_ops.split(
- value=prec, num_or_size_splits=2, axis=0)
- prec_low = array_ops.reshape(prec_low, shape=())
- prec_high = array_ops.reshape(prec_high, shape=())
- [rec_low, rec_high] = array_ops.split(
- value=rec, num_or_size_splits=2, axis=0)
- rec_low = array_ops.reshape(rec_low, shape=())
- rec_high = array_ops.reshape(rec_high, shape=())
+ prec_low = prec[0]
+ prec_high = prec[1]
+ rec_low = rec[0]
+ rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
thresholds)
- [prec_low, prec_high] = array_ops.split(
- value=prec, num_or_size_splits=2, axis=0)
- [rec_low, rec_high] = array_ops.split(
- value=rec, num_or_size_splits=2, axis=0)
+ prec_low = prec[0]
+ prec_high = prec[1]
+ rec_low = rec[0]
+ rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
@@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+ (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
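The later hunks in StreamingPrecisionRecallThresholdsTest also simplify how the per-threshold values are read out: indexing the rank-1 metric tensor (prec[0], prec[1]) yields the same scalar tensors as the removed split-and-reshape pattern. A rough sketch of the equivalence, assuming TensorFlow 1.x and using a constant stand-in for the length-2 precision tensor:

import tensorflow as tf

prec = tf.constant([0.25, 0.75])  # stand-in for a length-2 precision-at-thresholds tensor

# Pattern removed by this commit: split into two (1,)-shaped tensors, then reshape each to a scalar.
prec_low_old, prec_high_old = tf.split(value=prec, num_or_size_splits=2, axis=0)
prec_low_old = tf.reshape(prec_low_old, shape=())
prec_high_old = tf.reshape(prec_high_old, shape=())

# Pattern introduced by this commit: indexing a rank-1 tensor already yields scalar tensors.
prec_low_new = prec[0]
prec_high_new = prec[1]

with tf.Session() as sess:
    print(sess.run([prec_low_old, prec_high_old]))  # [0.25, 0.75]
    print(sess.run([prec_low_new, prec_high_new]))  # [0.25, 0.75]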