about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops_test.py')
-rw-r--r-- tensorflow/contrib/metrics/python/ops/metric_ops_test.py 262
1 files changed, 28 insertions, 234 deletions
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 5d0463e1f7..6a8e58b4da 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1708,34 +1708,6 @@ class StreamingCurvePointsTest(test.TestCase):
[[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
-def _np_auc(predictions, labels, weights=None):
- """Computes the AUC explicitly using Numpy.
-
- Args:
- predictions: an ndarray with shape [N].
- labels: an ndarray with shape [N].
- weights: an ndarray with shape [N].
-
- Returns:
- the area under the ROC curve.
- """
- if weights is None:
- weights = np.ones(np.size(predictions))
- is_positive = labels > 0
- num_positives = np.sum(weights[is_positive])
- num_negatives = np.sum(weights[~is_positive])
-
- # Sort descending:
- inds = np.argsort(-predictions)
-
- sorted_labels = labels[inds]
- sorted_weights = weights[inds]
- is_positive = sorted_labels > 0
-
- tp = np.cumsum(sorted_weights * is_positive) / num_positives
- return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
-
-
class StreamingAUCTest(test.TestCase):
def setUp(self):
@@ -1924,6 +1896,33 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
+ def np_auc(self, predictions, labels, weights):
+ """Computes the AUC explicitly using Numpy.
+
+ Args:
+ predictions: an ndarray with shape [N].
+ labels: an ndarray with shape [N].
+ weights: an ndarray with shape [N].
+
+ Returns:
+ the area under the ROC curve.
+ """
+ if weights is None:
+ weights = np.ones(np.size(predictions))
+ is_positive = labels > 0
+ num_positives = np.sum(weights[is_positive])
+ num_negatives = np.sum(weights[~is_positive])
+
+ # Sort descending:
+ inds = np.argsort(-predictions)
+
+ sorted_labels = labels[inds]
+ sorted_weights = weights[inds]
+ is_positive = sorted_labels > 0
+
+ tp = np.cumsum(sorted_weights * is_positive) / num_positives
+ return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
+
def testWithMultipleUpdates(self):
num_samples = 1000
batch_size = 10
@@ -1946,7 +1945,7 @@ class StreamingAUCTest(test.TestCase):
for weights in (None, np.ones(num_samples), np.random.exponential(
scale=1.0, size=num_samples)):
- expected_auc = _np_auc(predictions, labels, weights)
+ expected_auc = self.np_auc(predictions, labels, weights)
with self.test_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
@@ -1975,211 +1974,6 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval(), 2)
-class StreamingDynamicAUCTest(test.TestCase):
-
- def setUp(self):
- super(StreamingDynamicAUCTest, self).setUp()
- np.random.seed(1)
- ops.reset_default_graph()
-
- def testUnknownCurve(self):
- with self.assertRaisesRegexp(
- ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
- metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
- predictions=array_ops.ones((10, 1)),
- curve='TEST_CURVE')
-
- def testVars(self):
- metrics.streaming_dynamic_auc(
- labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
- _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
- 'dynamic_auc/concat_labels/size:0',
- 'dynamic_auc/concat_preds/array:0',
- 'dynamic_auc/concat_preds/size:0'])
-
- def testMetricsCollection(self):
- my_collection_name = '__metrics__'
- auc, _ = metrics.streaming_dynamic_auc(
- labels=array_ops.ones((10, 1)),
- predictions=array_ops.ones((10, 1)),
- metrics_collections=[my_collection_name])
- self.assertEqual(ops.get_collection(my_collection_name), [auc])
-
- def testUpdatesCollection(self):
- my_collection_name = '__updates__'
- _, update_op = metrics.streaming_dynamic_auc(
- labels=array_ops.ones((10, 1)),
- predictions=array_ops.ones((10, 1)),
- updates_collections=[my_collection_name])
- self.assertEqual(ops.get_collection(my_collection_name), [update_op])
-
- def testValueTensorIsIdempotent(self):
- predictions = random_ops.random_uniform(
- (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
- labels = random_ops.random_uniform(
- (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- # Run several updates.
- for _ in xrange(10):
- sess.run(update_op)
- # Then verify idempotency.
- initial_auc = auc.eval()
- for _ in xrange(10):
- self.assertAlmostEqual(initial_auc, auc.eval(), 5)
-
- def testAllLabelsOnes(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([1., 1., 1.])
- labels = constant_op.constant([1, 1, 1])
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertEqual(0, auc.eval())
-
- def testAllLabelsZeros(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([1., 1., 1.])
- labels = constant_op.constant([0, 0, 0])
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertEqual(0, auc.eval())
-
- def testNonZeroOnePredictions(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
- dtype=dtypes_lib.float32)
- labels = constant_op.constant([1, 0, 1, 0])
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(auc.eval(), 1.0)
-
- def testAllCorrect(self):
- inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
- predictions = constant_op.constant(inputs)
- labels = constant_op.constant(inputs)
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertEqual(1, auc.eval())
-
- def testSomeCorrect(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([1, 0, 1, 0])
- labels = constant_op.constant([0, 1, 1, 0])
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(0.5, auc.eval())
-
- def testAllIncorrect(self):
- inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
- predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
- labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
- auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(0, auc.eval())
-
- def testExceptionOnIncompatibleShapes(self):
- with self.test_session() as sess:
- predictions = array_ops.ones([5])
- labels = array_ops.zeros([6])
- with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
- _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
-
- def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
- labels = constant_op.constant([2, 1, 0])
- _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- '.*labels must be 0 or 1, at least one is >1.*'):
- sess.run(update_op)
-
- def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
- predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
- labels = constant_op.constant([1, 0, -1])
- _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- sess.run(variables.local_variables_initializer())
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- '.*labels must be 0 or 1, at least one is <0.*'):
- sess.run(update_op)
-
- def testWithMultipleUpdates(self):
- batch_size = 10
- num_batches = 100
- labels = np.array([])
- predictions = np.array([])
- tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
- array_ops.ones(batch_size),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.float32)
- auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
- with self.test_session() as sess:
- sess.run(variables.local_variables_initializer())
- for _ in xrange(num_batches):
- new_labels = np.random.randint(0, 2, size=batch_size)
- noise = np.random.normal(0.0, scale=0.2, size=batch_size)
- new_predictions = 0.4 + 0.2 * new_labels + noise
- labels = np.concatenate([labels, new_labels])
- predictions = np.concatenate([predictions, new_predictions])
- sess.run(tf_labels.assign(new_labels))
- sess.run(tf_predictions.assign(new_predictions))
- sess.run(update_op)
- expected_auc = _np_auc(predictions, labels)
- self.assertAlmostEqual(expected_auc, auc.eval())
-
- def testAUCPRReverseIncreasingPredictions(self):
- with self.test_session() as sess:
- predictions = constant_op.constant(
- [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
- labels = constant_op.constant([0, 0, 1, 1])
- auc, update_op = metrics.streaming_dynamic_auc(
- labels, predictions, curve='PR')
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
-
- def testAUCPRJumbledPredictions(self):
- with self.test_session() as sess:
- predictions = constant_op.constant(
- [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
- labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
- auc, update_op = metrics.streaming_dynamic_auc(
- labels, predictions, curve='PR')
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
-
- def testAUCPRPredictionsLessThanHalf(self):
- with self.test_session() as sess:
- predictions = constant_op.constant(
- [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
- shape=(1, 7),
- dtype=dtypes_lib.float32)
- labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
- auc, update_op = metrics.streaming_dynamic_auc(
- labels, predictions, curve='PR')
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
- self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
-
-
class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def setUp(self):