| author | Patrick Nguyen <drpng@google.com> | 2017-12-28 16:04:42 -0800 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-28 16:08:58 -0800 |
| commit | 20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch) | |
| tree | b429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/contrib/metrics | |
| parent | 2e2715baa84720f786b38d1f9cb6887399020d6f (diff) | |
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/contrib/metrics')
| -rw-r--r-- | tensorflow/contrib/metrics/__init__.py | 2 |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops.py | 124 |
| -rw-r--r-- | tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 208 |

3 files changed, 334 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 27dad5379a..d3dce46bfb 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -66,6 +66,7 @@ See the @{$python/contrib.metrics} guide.
 @@set_intersection
 @@set_size
 @@set_union
+@@cohen_kappa
 @@count
 @@precision_recall_at_equal_thresholds
 @@recall_at_precision
@@ -82,6 +83,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
 from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
 from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
 from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
+from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa
 from tensorflow.contrib.metrics.python.ops.metric_ops import count
 from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds
 from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2f27985634..c3de1c4c62 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -24,10 +24,12 @@ from __future__ import print_function
 
 import collections as collections_lib
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import confusion_matrix
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import metrics
@@ -3297,9 +3299,131 @@ def count(values,
   return count_, update_op
 
 
+def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
+                metrics_collections=None, updates_collections=None, name=None):
+  """Calculates Cohen's kappa.
+
+  [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
+  that measures inter-annotator agreement.
+
+  The `cohen_kappa` function calculates the confusion matrix, and creates three
+  local variables to compute Cohen's kappa: `po`, `pe_row`, and `pe_col`,
+  which refer to the diagonal part, row totals and column totals of the
+  confusion matrix, respectively. This value is ultimately returned as `kappa`,
+  an idempotent operation that is calculated by
+
+      pe = (pe_row * pe_col) / N
+      k = (sum(po) - sum(pe)) / (N - sum(pe))
+
+  For estimation of the metric over a stream of data, the function creates an
+  `update_op` operation that updates these variables and returns the
+  `kappa`. `update_op` weights each prediction by the corresponding value in
+  `weights`.
+
+  Class labels are expected to start at 0. E.g., if `num_classes`
+  was three, then the possible labels would be [0, 1, 2].
+
+  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+  NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but this function
+  does not support a weighted confusion matrix yet.
+
+  Args:
+    labels: 1-D `Tensor` of real labels for the classification task. Must be
+      one of the following types: int16, int32, int64.
+    predictions_idx: 1-D `Tensor` of predicted class indices for a given
+      classification. Must have the same type as `labels`.
+    num_classes: The number of possible labels.
+    weights: Optional `Tensor` whose shape matches `predictions_idx`.
+    metrics_collections: An optional list of collections that `kappa` should
+      be added to.
+    updates_collections: An optional list of collections that `update_op`
+      should be added to.
+    name: An optional variable_scope name.
+
+  Returns:
+    kappa: Scalar float `Tensor` representing the current Cohen's kappa.
+    update_op: `Operation` that increments `po`, `pe_row` and `pe_col`
+      variables appropriately and whose value matches `kappa`.
+
+  Raises:
+    ValueError: If `num_classes` is less than 2, or `predictions_idx` and
+      `labels` have mismatched shapes, or if `weights` is not `None` and its
+      shape doesn't match `predictions_idx`, or if either `metrics_collections`
+      or `updates_collections` are not a list or tuple.
+    RuntimeError: If eager execution is enabled.
+  """
+  if context.in_eager_mode():
+    raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported '
+                       'when eager execution is enabled.')
+  if num_classes < 2:
+    raise ValueError('`num_classes` must be >= 2. '
+                     'Found: {}'.format(num_classes))
+  with variable_scope.variable_scope(name, 'cohen_kappa',
+                                     (labels, predictions_idx, weights)):
+    # Convert 2-dim (num, 1) to 1-dim (num,)
+    labels.get_shape().with_rank_at_most(2)
+    if labels.get_shape().ndims == 2:
+      labels = array_ops.squeeze(labels, axis=[-1])
+    predictions_idx, labels, weights = (
+        metrics_impl._remove_squeezable_dimensions(  # pylint: disable=protected-access
+            predictions=predictions_idx, labels=labels, weights=weights))
+    predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())
+
+    stat_dtype = (dtypes.int64
+                  if weights is None or weights.dtype.is_integer
+                  else dtypes.float32)
+    po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po')
+    pe_row = metrics_impl.metric_variable(
+        (num_classes,), stat_dtype, name='pe_row')
+    pe_col = metrics_impl.metric_variable(
+        (num_classes,), stat_dtype, name='pe_col')
+
+    # Table of the counts of agreement:
+    counts_in_table = confusion_matrix.confusion_matrix(
+        labels, predictions_idx,
+        num_classes=num_classes, weights=weights,
+        dtype=stat_dtype, name='counts_in_table')
+
+    po_t = array_ops.diag_part(counts_in_table)
+    pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
+    pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1)
+    update_po = state_ops.assign_add(po, po_t)
+    update_pe_row = state_ops.assign_add(pe_row, pe_row_t)
+    update_pe_col = state_ops.assign_add(pe_col, pe_col_t)
+
+    def _calculate_k(po, pe_row, pe_col, name):
+      po_sum = math_ops.reduce_sum(po)
+      total = math_ops.reduce_sum(pe_row)
+      pe_sum = math_ops.reduce_sum(
+          metrics_impl._safe_div(  # pylint: disable=protected-access
+              pe_row * pe_col, total, None))
+      po_sum, pe_sum, total = (math_ops.to_double(po_sum),
+                               math_ops.to_double(pe_sum),
+                               math_ops.to_double(total))
+      # kappa = (po - pe) / (N - pe)
+      k = metrics_impl._safe_scalar_div(  # pylint: disable=protected-access
+          po_sum - pe_sum, total - pe_sum, name=name)
+      return k
+
+    kappa = _calculate_k(po, pe_row, pe_col, name='value')
+    update_op = _calculate_k(update_po, update_pe_row, update_pe_col,
+                             name='update_op')
+
+    if metrics_collections:
+      ops.add_to_collections(metrics_collections, kappa)
+
+    if updates_collections:
+      ops.add_to_collections(updates_collections, update_op)
+
+    return kappa, update_op
+
+
 __all__ = [
     'aggregate_metric_map',
     'aggregate_metrics',
+    'cohen_kappa',
     'count',
     'precision_recall_at_equal_thresholds',
     'recall_at_precision',
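For reference, the `po`/`pe_row`/`pe_col` bookkeeping above can be cross-checked against a plain NumPy version of the same formula. This is a minimal sketch, not part of the change; `cohen_kappa_numpy` and its toy inputs are made up for illustration:

```python
# Minimal NumPy sketch of the kappa computed by cohen_kappa() above.
# Not part of this commit; names mirror the po/pe_row/pe_col variables.
import numpy as np

def cohen_kappa_numpy(labels, predictions, num_classes, weights=None):
    if weights is None:
        weights = np.ones_like(labels, dtype=np.float64)
    # Weighted confusion matrix, rows = labels, columns = predictions.
    table = np.zeros((num_classes, num_classes), dtype=np.float64)
    for y, p, w in zip(labels, predictions, weights):
        table[y, p] += w
    po = np.diag(table)          # agreement counts (diagonal)
    pe_row = table.sum(axis=0)   # column totals
    pe_col = table.sum(axis=1)   # row totals
    total = pe_row.sum()         # N
    pe_sum = np.sum(pe_row * pe_col / total)
    # kappa = (sum(po) - sum(pe)) / (N - sum(pe))
    return (po.sum() - pe_sum) / (total - pe_sum)

labels = np.array([0, 0, 1, 1, 2, 2, 2, 0])
predictions = np.array([0, 1, 1, 1, 2, 0, 2, 0])
print(cohen_kappa_numpy(labels, predictions, num_classes=3))
```

Because only the diagonal and the two marginals are accumulated, summing them across batches before the final division gives the same kappa as computing it on the pooled data, which is what makes the streaming `update_op` design work.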
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f05ae394e6..89aa29f711 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -6660,5 +6660,213 @@ class CountTest(test.TestCase):
     self.assertAlmostEqual(4.1, result.eval(), 5)
 
 
+class CohenKappaTest(test.TestCase):
+
+  def _confusion_matrix_to_samples(self, confusion_matrix):
+    x, y = confusion_matrix.shape
+    pairs = []
+    for label in range(x):
+      for feature in range(y):
+        pairs += [label, feature] * confusion_matrix[label, feature]
+    pairs = np.array(pairs).reshape((-1, 2))
+    return pairs[:, 0], pairs[:, 1]
+
+  def setUp(self):
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testVars(self):
+    metrics.cohen_kappa(
+        predictions_idx=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_classes=2)
+    _assert_metric_variables(self, (
+        'cohen_kappa/po:0',
+        'cohen_kappa/pe_row:0',
+        'cohen_kappa/pe_col:0',))
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    kappa, _ = metrics.cohen_kappa(
+        predictions_idx=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_classes=2,
+        metrics_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [kappa])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, update_op = metrics.cohen_kappa(
+        predictions_idx=array_ops.ones((10, 1)),
+        labels=array_ops.ones((10, 1)),
+        num_classes=2,
+        updates_collections=[my_collection_name])
+    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
+    kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
+
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      # Run several updates.
+      for _ in range(10):
+        sess.run(update_op)
+
+      # Then verify idempotency.
+      initial_kappa = kappa.eval()
+      for _ in range(10):
+        self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)
+
+  def testBasic(self):
+    confusion_matrix = np.array([
+        [9, 3, 1],
+        [4, 8, 2],
+        [2, 1, 6]])
+    # overall total = 36
+    # po = [9, 8, 6], sum(po) = 23
+    # pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
+    # finally, kappa = (sum(po) - sum(pe)) / (N - sum(pe))
+    #                = (23 - 12.34) / (36 - 12.34)
+    #                = 0.45
+    # see: http://psych.unl.edu/psycrs/handcomp/hckappa.PDF
+    expect = 0.45
+    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+
+    dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
+    shapes = [(len(labels),),  # 1-dim
+              (len(labels), 1)]  # 2-dim
+    weights = [None, np.ones_like(labels)]
+
+    for dtype in dtypes:
+      for shape in shapes:
+        for weight in weights:
+          with self.test_session() as sess:
+            predictions_tensor = constant_op.constant(
+                np.reshape(predictions, shape), dtype=dtype)
+            labels_tensor = constant_op.constant(
+                np.reshape(labels, shape), dtype=dtype)
+            kappa, update_op = metrics.cohen_kappa(
+                labels_tensor, predictions_tensor, 3, weights=weight)
+
+            sess.run(variables.local_variables_initializer())
+            self.assertAlmostEqual(expect, sess.run(update_op), 2)
+            self.assertAlmostEqual(expect, kappa.eval(), 2)
+
+  def testAllCorrect(self):
+    inputs = np.arange(0, 100) % 4
+    # confusion matrix
+    # [[25, 0, 0, 0],
+    #  [0, 25, 0, 0],
+    #  [0, 0, 25, 0],
+    #  [0, 0, 0, 25]]
+    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
+    expect = 1.0
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+      labels = constant_op.constant(inputs)
+      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+      sess.run(variables.local_variables_initializer())
+      self.assertAlmostEqual(expect, sess.run(update_op), 5)
+      self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+  def testAllIncorrect(self):
+    labels = np.arange(0, 100) % 4
+    predictions = (labels + 1) % 4
+    # confusion matrix
+    # [[0, 25, 0, 0],
+    #  [0, 0, 25, 0],
+    #  [0, 0, 0, 25],
+    #  [25, 0, 0, 0]]
+    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
+    expect = -0.333333333333
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+      labels = constant_op.constant(labels)
+      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+      sess.run(variables.local_variables_initializer())
+      self.assertAlmostEqual(expect, sess.run(update_op), 5)
+      self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+  def testWeighted(self):
+    confusion_matrix = np.array([
+        [9, 3, 1],
+        [4, 8, 2],
+        [2, 1, 6]])
+    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+    num_samples = np.sum(confusion_matrix, dtype=np.int32)
+    weights = (np.arange(0, num_samples) % 5) / 5.0
+    # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+    #     labels, predictions, sample_weight=weights)
+    expect = 0.453466583385
+
+    with self.test_session() as sess:
+      predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+      labels = constant_op.constant(labels)
+      kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
+                                             weights=weights)
+
+      sess.run(variables.local_variables_initializer())
+      self.assertAlmostEqual(expect, sess.run(update_op), 5)
+      self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+  def testWithMultipleUpdates(self):
+    confusion_matrix = np.array([
+        [90, 30, 10, 20],
+        [40, 80, 20, 30],
+        [20, 10, 60, 35],
+        [15, 25, 30, 25]])
+    labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+    num_samples = np.sum(confusion_matrix, dtype=np.int32)
+    weights = (np.arange(0, num_samples) % 5) / 5.0
+    num_classes = confusion_matrix.shape[0]
+
+    batch_size = num_samples // 10
+    predictions_t = array_ops.placeholder(dtypes_lib.float32,
+                                          shape=(batch_size,))
+    labels_t = array_ops.placeholder(dtypes_lib.int32,
+                                     shape=(batch_size,))
+    weights_t = array_ops.placeholder(dtypes_lib.float32,
+                                      shape=(batch_size,))
+    kappa, update_op = metrics.cohen_kappa(
+        labels_t, predictions_t, num_classes, weights=weights_t)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+
+      for idx in range(0, num_samples, batch_size):
+        batch_start, batch_end = idx, idx + batch_size
+        sess.run(update_op,
+                 feed_dict={labels_t: labels[batch_start:batch_end],
+                            predictions_t: predictions[batch_start:batch_end],
+                            weights_t: weights[batch_start:batch_end]})
+      # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+      #     labels_np, predictions_np, sample_weight=weights_np)
+      expect = 0.289965397924
+      self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+  def testInvalidNumClasses(self):
+    predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+    labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+    with self.assertRaisesRegexp(ValueError, 'num_classes'):
+      metrics.cohen_kappa(labels, predictions, 1)
+
+  def testInvalidDimension(self):
+    predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+    invalid_labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 2))
+    with self.assertRaises(ValueError):
+      metrics.cohen_kappa(invalid_labels, predictions, 3)
+
+    invalid_predictions = array_ops.placeholder(dtypes_lib.float32,
+                                                shape=(4, 2))
+    labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+    with self.assertRaises(ValueError):
+      metrics.cohen_kappa(labels, invalid_predictions, 3)
+
+
 if __name__ == '__main__':
   test.main()
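As a usage note, the streaming pattern the new tests exercise looks like this in application code. This is a hypothetical TF 1.x graph-mode sketch, not part of the commit; the toy batches are made up:

```python
# Hypothetical usage sketch of the new streaming metric (TF 1.x graph mode).
import tensorflow as tf

labels = tf.placeholder(tf.int32, shape=(None,))
predictions = tf.placeholder(tf.int32, shape=(None,))
kappa, update_op = tf.contrib.metrics.cohen_kappa(labels, predictions,
                                                  num_classes=3)

with tf.Session() as sess:
  # po/pe_row/pe_col are local variables, so initialize locals, not globals.
  sess.run(tf.local_variables_initializer())
  for batch_labels, batch_predictions in [([0, 1, 2], [0, 1, 1]),
                                          ([2, 2, 0], [2, 1, 0])]:
    sess.run(update_op, feed_dict={labels: batch_labels,
                                   predictions: batch_predictions})
  print(sess.run(kappa))  # kappa accumulated over both batches
```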