aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-05 11:14:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-05 11:19:05 -0700
commitfd9c3f65870718a52c3627569aafc460febaf488 (patch)
tree0111e825b7dda0d499ba0bd152371b1a75f73597 /tensorflow/contrib/metrics
parentf841f34de31e04b3e997b0cb8f31a1286cdc7851 (diff)
Verify that predictions are in the expected range for ops that use thresholds, e.g. tf.contrib.metrics.streaming_auc.
PiperOrigin-RevId: 167604306
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 00cde08bff..9b959b43a9 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1496,6 +1496,15 @@ class StreamingAUCTest(test.TestCase):
for _ in range(10):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
+ def testPredictionsOutOfRange(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ _, update_op = metrics.streaming_auc(predictions, labels)
+ sess.run(variables.local_variables_initializer())
+ self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval)
+
def testAllCorrect(self):
self.allCorrectAsExpected('ROC')