diff options
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops_test.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 262 |
1 file changed, 28 insertions(+), 234 deletions(-)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 5d0463e1f7..6a8e58b4da 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1708,34 +1708,6 @@ class StreamingCurvePointsTest(test.TestCase): [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]) -def _np_auc(predictions, labels, weights=None): - """Computes the AUC explicitly using Numpy. - - Args: - predictions: an ndarray with shape [N]. - labels: an ndarray with shape [N]. - weights: an ndarray with shape [N]. - - Returns: - the area under the ROC curve. - """ - if weights is None: - weights = np.ones(np.size(predictions)) - is_positive = labels > 0 - num_positives = np.sum(weights[is_positive]) - num_negatives = np.sum(weights[~is_positive]) - - # Sort descending: - inds = np.argsort(-predictions) - - sorted_labels = labels[inds] - sorted_weights = weights[inds] - is_positive = sorted_labels > 0 - - tp = np.cumsum(sorted_weights * is_positive) / num_positives - return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives - - class StreamingAUCTest(test.TestCase): def setUp(self): @@ -1924,6 +1896,33 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) + def np_auc(self, predictions, labels, weights): + """Computes the AUC explicitly using Numpy. + + Args: + predictions: an ndarray with shape [N]. + labels: an ndarray with shape [N]. + weights: an ndarray with shape [N]. + + Returns: + the area under the ROC curve. 
+ """ + if weights is None: + weights = np.ones(np.size(predictions)) + is_positive = labels > 0 + num_positives = np.sum(weights[is_positive]) + num_negatives = np.sum(weights[~is_positive]) + + # Sort descending: + inds = np.argsort(-predictions) + + sorted_labels = labels[inds] + sorted_weights = weights[inds] + is_positive = sorted_labels > 0 + + tp = np.cumsum(sorted_weights * is_positive) / num_positives + return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives + def testWithMultipleUpdates(self): num_samples = 1000 batch_size = 10 @@ -1946,7 +1945,7 @@ class StreamingAUCTest(test.TestCase): for weights in (None, np.ones(num_samples), np.random.exponential( scale=1.0, size=num_samples)): - expected_auc = _np_auc(predictions, labels, weights) + expected_auc = self.np_auc(predictions, labels, weights) with self.test_session() as sess: enqueue_ops = [[] for i in range(num_batches)] @@ -1975,211 +1974,6 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval(), 2) -class StreamingDynamicAUCTest(test.TestCase): - - def setUp(self): - super(StreamingDynamicAUCTest, self).setUp() - np.random.seed(1) - ops.reset_default_graph() - - def testUnknownCurve(self): - with self.assertRaisesRegexp( - ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'): - metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)), - predictions=array_ops.ones((10, 1)), - curve='TEST_CURVE') - - def testVars(self): - metrics.streaming_dynamic_auc( - labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1))) - _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0', - 'dynamic_auc/concat_labels/size:0', - 'dynamic_auc/concat_preds/array:0', - 'dynamic_auc/concat_preds/size:0']) - - def testMetricsCollection(self): - my_collection_name = '__metrics__' - auc, _ = metrics.streaming_dynamic_auc( - labels=array_ops.ones((10, 1)), - predictions=array_ops.ones((10, 1)), - metrics_collections=[my_collection_name]) - 
self.assertEqual(ops.get_collection(my_collection_name), [auc]) - - def testUpdatesCollection(self): - my_collection_name = '__updates__' - _, update_op = metrics.streaming_dynamic_auc( - labels=array_ops.ones((10, 1)), - predictions=array_ops.ones((10, 1)), - updates_collections=[my_collection_name]) - self.assertEqual(ops.get_collection(my_collection_name), [update_op]) - - def testValueTensorIsIdempotent(self): - predictions = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) - labels = random_ops.random_uniform( - (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - # Run several updates. - for _ in xrange(10): - sess.run(update_op) - # Then verify idempotency. - initial_auc = auc.eval() - for _ in xrange(10): - self.assertAlmostEqual(initial_auc, auc.eval(), 5) - - def testAllLabelsOnes(self): - with self.test_session() as sess: - predictions = constant_op.constant([1., 1., 1.]) - labels = constant_op.constant([1, 1, 1]) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertEqual(0, auc.eval()) - - def testAllLabelsZeros(self): - with self.test_session() as sess: - predictions = constant_op.constant([1., 1., 1.]) - labels = constant_op.constant([0, 0, 0]) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertEqual(0, auc.eval()) - - def testNonZeroOnePredictions(self): - with self.test_session() as sess: - predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5], - dtype=dtypes_lib.float32) - labels = constant_op.constant([1, 0, 1, 0]) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - 
sess.run(update_op) - self.assertAlmostEqual(auc.eval(), 1.0) - - def testAllCorrect(self): - inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: - predictions = constant_op.constant(inputs) - labels = constant_op.constant(inputs) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertEqual(1, auc.eval()) - - def testSomeCorrect(self): - with self.test_session() as sess: - predictions = constant_op.constant([1, 0, 1, 0]) - labels = constant_op.constant([0, 1, 1, 0]) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertAlmostEqual(0.5, auc.eval()) - - def testAllIncorrect(self): - inputs = np.random.randint(0, 2, size=(100, 1)) - with self.test_session() as sess: - predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) - labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) - auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertAlmostEqual(0, auc.eval()) - - def testExceptionOnIncompatibleShapes(self): - with self.test_session() as sess: - predictions = array_ops.ones([5]) - labels = array_ops.zeros([6]) - with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): - _, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - - def testExceptionOnGreaterThanOneLabel(self): - with self.test_session() as sess: - predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) - labels = constant_op.constant([2, 1, 0]) - _, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - with self.assertRaisesRegexp( - errors_impl.InvalidArgumentError, - 
'.*labels must be 0 or 1, at least one is >1.*'): - sess.run(update_op) - - def testExceptionOnNegativeLabel(self): - with self.test_session() as sess: - predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) - labels = constant_op.constant([1, 0, -1]) - _, update_op = metrics.streaming_dynamic_auc(labels, predictions) - sess.run(variables.local_variables_initializer()) - with self.assertRaisesRegexp( - errors_impl.InvalidArgumentError, - '.*labels must be 0 or 1, at least one is <0.*'): - sess.run(update_op) - - def testWithMultipleUpdates(self): - batch_size = 10 - num_batches = 100 - labels = np.array([]) - predictions = np.array([]) - tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - dtype=dtypes_lib.int32) - tf_predictions = variables.Variable( - array_ops.ones(batch_size), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - dtype=dtypes_lib.float32) - auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions) - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - for _ in xrange(num_batches): - new_labels = np.random.randint(0, 2, size=batch_size) - noise = np.random.normal(0.0, scale=0.2, size=batch_size) - new_predictions = 0.4 + 0.2 * new_labels + noise - labels = np.concatenate([labels, new_labels]) - predictions = np.concatenate([predictions, new_predictions]) - sess.run(tf_labels.assign(new_labels)) - sess.run(tf_predictions.assign(new_predictions)) - sess.run(update_op) - expected_auc = _np_auc(predictions, labels) - self.assertAlmostEqual(expected_auc, auc.eval()) - - def testAUCPRReverseIncreasingPredictions(self): - with self.test_session() as sess: - predictions = constant_op.constant( - [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32) - labels = constant_op.constant([0, 0, 1, 1]) - auc, update_op = metrics.streaming_dynamic_auc( - labels, predictions, curve='PR') - sess.run(variables.local_variables_initializer()) 
- sess.run(update_op) - self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5) - - def testAUCPRJumbledPredictions(self): - with self.test_session() as sess: - predictions = constant_op.constant( - [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32) - labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1]) - auc, update_op = metrics.streaming_dynamic_auc( - labels, predictions, curve='PR') - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6) - - def testAUCPRPredictionsLessThanHalf(self): - with self.test_session() as sess: - predictions = constant_op.constant( - [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], - shape=(1, 7), - dtype=dtypes_lib.float32) - labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) - auc, update_op = metrics.streaming_dynamic_auc( - labels, predictions, curve='PR') - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) - - class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): def setUp(self): |