diff options
Diffstat (limited to 'tensorflow/contrib/metrics/python/ops/metric_ops_test.py')
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 208 |
1 file changed, 208 insertions, 0 deletions
class CohenKappaTest(test.TestCase):
  """Tests for metrics.cohen_kappa (Cohen's kappa inter-rater agreement)."""

  def _confusion_matrix_to_samples(self, confusion_matrix):
    """Expand a confusion matrix into parallel (labels, predictions) arrays.

    Each cell (i, j) with count c contributes c samples whose label is i and
    whose prediction is j.

    Args:
      confusion_matrix: 2-D numpy integer array of per-cell counts.

    Returns:
      A (labels, predictions) pair of 1-D numpy arrays.
    """
    x, y = confusion_matrix.shape
    pairs = []
    for label in range(x):
      for feature in range(y):
        # Repeat the [label, feature] pair once per count in this cell.
        pairs += [label, feature] * confusion_matrix[label, feature]
    pairs = np.array(pairs).reshape((-1, 2))
    return pairs[:, 0], pairs[:, 1]

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    """The metric creates exactly the po / pe_row / pe_col local variables."""
    metrics.cohen_kappa(
        predictions_idx=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        num_classes=2)
    _assert_metric_variables(self, (
        'cohen_kappa/po:0',
        'cohen_kappa/pe_row:0',
        'cohen_kappa/pe_col:0',))

  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    kappa, _ = metrics.cohen_kappa(
        predictions_idx=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        num_classes=2,
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [kappa])

  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.cohen_kappa(
        predictions_idx=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        num_classes=2,
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def testValueTensorIsIdempotent(self):
    """Evaluating the value tensor repeatedly must not change the result."""
    predictions = random_ops.random_uniform(
        (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
    kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        sess.run(update_op)

      # Then verify idempotency.
      initial_kappa = kappa.eval()
      for _ in range(10):
        self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)

  def testBasic(self):
    """Kappa over a hand-computed 3-class confusion matrix, all dtypes/shapes."""
    confusion_matrix = np.array([
        [9, 3, 1],
        [4, 8, 2],
        [2, 1, 6]])
    # overall total = 36
    # po = [9, 8, 6], sum(po) = 23
    # pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
    # finally, kappa = (sum(po) - sum(pe)) / (N - sum(pe))
    #               = (23 - 12.34) / (36 - 12.34)
    #               = 0.45
    # see: http://psych.unl.edu/psycrs/handcomp/hckappa.PDF
    expect = 0.45
    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)

    dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
    # Fixed misplaced comma: the original `(len(labels,))` was a bare int, not
    # the intended 1-tuple; np.reshape accepts both, behavior is unchanged.
    shapes = [(len(labels),),    # 1-dim
              (len(labels), 1)]  # 2-dim
    weights = [None, np.ones_like(labels)]

    for dtype in dtypes:
      for shape in shapes:
        for weight in weights:
          with self.test_session() as sess:
            predictions_tensor = constant_op.constant(
                np.reshape(predictions, shape), dtype=dtype)
            labels_tensor = constant_op.constant(
                np.reshape(labels, shape), dtype=dtype)
            kappa, update_op = metrics.cohen_kappa(
                labels_tensor, predictions_tensor, 3, weights=weight)

            sess.run(variables.local_variables_initializer())
            self.assertAlmostEqual(expect, sess.run(update_op), 2)
            self.assertAlmostEqual(expect, kappa.eval(), 2)

  def testAllCorrect(self):
    """Perfect agreement over 4 classes yields kappa == 1."""
    inputs = np.arange(0, 100) % 4
    # confusion matrix (4 classes, 25 samples each, all on the diagonal)
    # [[25, 0, 0, 0],
    #  [0, 25, 0, 0],
    #  [0, 0, 25, 0],
    #  [0, 0, 0, 25]]
    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
    expect = 1.0

    with self.test_session() as sess:
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(expect, sess.run(update_op), 5)
      self.assertAlmostEqual(expect, kappa.eval(), 5)

  def testAllIncorrect(self):
    """Every prediction shifted by one class yields kappa == -1/3."""
    labels = np.arange(0, 100) % 4
    predictions = (labels + 1) % 4
    # confusion matrix (4 classes, every prediction cyclically shifted by one)
    # [[0, 25, 0, 0],
    #  [0, 0, 25, 0],
    #  [0, 0, 0, 25],
    #  [25, 0, 0, 0]]
    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
    expect = -0.333333333333

    with self.test_session() as sess:
      predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
      labels = constant_op.constant(labels)
      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(expect, sess.run(update_op), 5)
      self.assertAlmostEqual(expect, kappa.eval(), 5)

  def testWeighted(self):
    """Per-sample weights match sklearn's sample_weight behavior."""
    confusion_matrix = np.array([
        [9, 3, 1],
        [4, 8, 2],
        [2, 1, 6]])
    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
    num_samples = np.sum(confusion_matrix, dtype=np.int32)
    weights = (np.arange(0, num_samples) % 5) / 5.0
    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
    #     labels, predictions, sample_weight=weights)
    expect = 0.453466583385

    with self.test_session() as sess:
      predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
      labels = constant_op.constant(labels)
      # NOTE(review): num_classes=4 although only 3 classes occur; the empty
      # extra class contributes 0 to both po and pe, so kappa is unaffected —
      # confirm this over-allocation is intentional.
      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
                                             weights=weights)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(expect, sess.run(update_op), 5)
      self.assertAlmostEqual(expect, kappa.eval(), 5)

  def testWithMultipleUpdates(self):
    """Streaming kappa accumulated over mini-batches matches sklearn."""
    confusion_matrix = np.array([
        [90, 30, 10, 20],
        [40, 80, 20, 30],
        [20, 10, 60, 35],
        [15, 25, 30, 25]])
    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
    num_samples = np.sum(confusion_matrix, dtype=np.int32)
    weights = (np.arange(0, num_samples) % 5) / 5.0
    num_classes = confusion_matrix.shape[0]

    batch_size = num_samples // 10
    predictions_t = array_ops.placeholder(dtypes_lib.float32,
                                          shape=(batch_size,))
    labels_t = array_ops.placeholder(dtypes_lib.int32,
                                     shape=(batch_size,))
    weights_t = array_ops.placeholder(dtypes_lib.float32,
                                      shape=(batch_size,))
    kappa, update_op = metrics.cohen_kappa(
        labels_t, predictions_t, num_classes, weights=weights_t)
    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Feed the samples through in consecutive, equally-sized batches.
      for idx in range(0, num_samples, batch_size):
        batch_start, batch_end = idx, idx + batch_size
        sess.run(update_op,
                 feed_dict={labels_t: labels[batch_start:batch_end],
                            predictions_t: predictions[batch_start:batch_end],
                            weights_t: weights[batch_start:batch_end]})
      # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
      #     labels, predictions, sample_weight=weights)
      expect = 0.289965397924
      self.assertAlmostEqual(expect, kappa.eval(), 5)

  def testInvalidNumClasses(self):
    """num_classes < 2 must be rejected at graph-construction time."""
    predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
    labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
    with self.assertRaisesRegexp(ValueError, 'num_classes'):
      metrics.cohen_kappa(labels, predictions, 1)

  def testInvalidDimension(self):
    """Inputs with a trailing dimension != 1 must be rejected."""
    predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
    invalid_labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 2))
    with self.assertRaises(ValueError):
      metrics.cohen_kappa(invalid_labels, predictions, 3)

    invalid_predictions = array_ops.placeholder(dtypes_lib.float32,
                                                shape=(4, 2))
    labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
    with self.assertRaises(ValueError):
      metrics.cohen_kappa(labels, invalid_predictions, 3)


if __name__ == '__main__':
  test.main()