path: root/tensorflow/contrib/metrics
author	Patrick Nguyen <drpng@google.com>	2017-12-28 16:04:42 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-12-28 16:08:58 -0800
commit	20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch)
tree	b429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/contrib/metrics
parent	2e2715baa84720f786b38d1f9cb6887399020d6f (diff)
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/contrib/metrics')
-rw-r--r--	tensorflow/contrib/metrics/__init__.py	2
-rw-r--r--	tensorflow/contrib/metrics/python/ops/metric_ops.py	124
-rw-r--r--	tensorflow/contrib/metrics/python/ops/metric_ops_test.py	208
3 files changed, 334 insertions, 0 deletions
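
For reference, a minimal usage sketch of the `tf.contrib.metrics.cohen_kappa` API introduced by this change (illustrative only, not part of the diff; assumes TensorFlow 1.x graph mode with a session):

    import tensorflow as tf

    labels = tf.constant([0, 1, 2, 0, 1, 2], dtype=tf.int32)
    predictions = tf.constant([0, 1, 1, 0, 2, 2], dtype=tf.int32)

    # Returns the current kappa value and an op that accumulates
    # po, pe_row and pe_col from this batch of data.
    kappa, update_op = tf.contrib.metrics.cohen_kappa(
        labels, predictions, num_classes=3)

    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(kappa))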
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 27dad5379a..d3dce46bfb 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -66,6 +66,7 @@ See the @{$python/contrib.metrics} guide.
@@set_intersection
@@set_size
@@set_union
+@@cohen_kappa
@@count
@@precision_recall_at_equal_thresholds
@@recall_at_precision
@@ -82,6 +83,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion
from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map
from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics
+from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa
from tensorflow.contrib.metrics.python.ops.metric_ops import count
from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 2f27985634..c3de1c4c62 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -24,10 +24,12 @@ from __future__ import print_function
import collections as collections_lib
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -3297,9 +3299,131 @@ def count(values,
return count_, update_op
+def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
+ metrics_collections=None, updates_collections=None, name=None):
+ """Calculates Cohen's kappa.
+
+ [Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
+ that measures inter-annotator agreement.
+
+ The `cohen_kappa` function calculates the confusion matrix and creates three
+ local variables to compute Cohen's kappa: `po`, `pe_row`, and `pe_col`, which
+ hold the diagonal entries, row totals, and column totals of the confusion
+ matrix, respectively. This value is ultimately returned as `kappa`, an
+ idempotent operation that is calculated by
+
+ pe = (pe_row * pe_col) / N
+ k = (sum(po) - sum(pe)) / (N - sum(pe))
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the
+ `kappa`. `update_op` weights each prediction by the corresponding value in
+ `weights`.
+
+ Class labels are expected to start at 0. E.g., if `num_classes`
+ was three, then the possible labels would be [0, 1, 2].
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ NOTE: Equivalent to `sklearn.metrics.cohen_kappa_score`, but this method
+ does not support a weighted confusion matrix yet.
+
+ Args:
+ labels: 1-D `Tensor` of real labels for the classification task. Must be
+ one of the following types: int16, int32, int64.
+ predictions_idx: 1-D `Tensor` of predicted class indices for a given
+ classification. Must have the same type as `labels`.
+ num_classes: The possible number of labels.
+ weights: Optional `Tensor` whose shape matches `predictions_idx`.
+ metrics_collections: An optional list of collections that `kappa` should
+ be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ kappa: Scalar float `Tensor` representing the current Cohen's kappa.
+ update_op: `Operation` that increments `po`, `pe_row` and `pe_col`
+ variables appropriately and whose value matches `kappa`.
+
+ Raises:
+ ValueError: If `num_classes` is less than 2, or `predictions_idx` and
+ `labels` have mismatched shapes, or if `weights` is not `None` and its
+ shape doesn't match `predictions_idx`, or if either `metrics_collections`
+ or `updates_collections` is not a list or tuple.
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.in_eager_mode():
+ raise RuntimeError('tf.contrib.metrics.cohen_kappa is not supported '
+ 'when eager execution is enabled.')
+ if num_classes < 2:
+ raise ValueError('`num_classes` must be >= 2. '
+ 'Found: {}'.format(num_classes))
+ with variable_scope.variable_scope(name, 'cohen_kappa',
+ (labels, predictions_idx, weights)):
+ # Convert 2-dim (num, 1) to 1-dim (num,)
+ labels.get_shape().with_rank_at_most(2)
+ if labels.get_shape().ndims == 2:
+ labels = array_ops.squeeze(labels, axis=[-1])
+ predictions_idx, labels, weights = (
+ metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions_idx, labels=labels, weights=weights))
+ predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())
+
+ stat_dtype = (dtypes.int64
+ if weights is None or weights.dtype.is_integer
+ else dtypes.float32)
+ po = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='po')
+ pe_row = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='pe_row')
+ pe_col = metrics_impl.metric_variable(
+ (num_classes,), stat_dtype, name='pe_col')
+
+ # Table of the counts of agreement:
+ counts_in_table = confusion_matrix.confusion_matrix(
+ labels, predictions_idx,
+ num_classes=num_classes, weights=weights,
+ dtype=stat_dtype, name='counts_in_table')
+
+ po_t = array_ops.diag_part(counts_in_table)
+ pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
+ pe_col_t = math_ops.reduce_sum(counts_in_table, axis=1)
+ update_po = state_ops.assign_add(po, po_t)
+ update_pe_row = state_ops.assign_add(pe_row, pe_row_t)
+ update_pe_col = state_ops.assign_add(pe_col, pe_col_t)
+
+ def _calculate_k(po, pe_row, pe_col, name):
+ po_sum = math_ops.reduce_sum(po)
+ total = math_ops.reduce_sum(pe_row)
+ pe_sum = math_ops.reduce_sum(
+ metrics_impl._safe_div( # pylint: disable=protected-access
+ pe_row * pe_col, total, None))
+ po_sum, pe_sum, total = (math_ops.to_double(po_sum),
+ math_ops.to_double(pe_sum),
+ math_ops.to_double(total))
+ # kappa = (po - pe) / (N - pe)
+ k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access
+ po_sum - pe_sum, total - pe_sum, name=name)
+ return k
+
+ kappa = _calculate_k(po, pe_row, pe_col, name='value')
+ update_op = _calculate_k(update_po, update_pe_row, update_pe_col,
+ name='update_op')
+
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, kappa)
+
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return kappa, update_op
+
+
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
+ 'cohen_kappa',
'count',
'precision_recall_at_equal_thresholds',
'recall_at_precision',
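
As a quick sanity check of the kappa formula in the docstring above, the same computation can be sketched directly in NumPy on the 3x3 confusion matrix used by testBasic below (assumes NumPy; not part of this change):

    import numpy as np

    # Rows are labels, columns are predictions.
    cm = np.array([[9, 3, 1],
                   [4, 8, 2],
                   [2, 1, 6]], dtype=np.float64)
    n = cm.sum()                                      # 36
    po = np.diag(cm).sum()                            # 23
    pe = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / n  # ~12.33
    kappa = (po - pe) / (n - pe)                      # ~0.45
    print(kappa)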
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index f05ae394e6..89aa29f711 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -6660,5 +6660,213 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
+class CohenKappaTest(test.TestCase):
+
+ def _confusion_matrix_to_samples(self, confusion_matrix):
+ x, y = confusion_matrix.shape
+ pairs = []
+ for label in range(x):
+ for feature in range(y):
+ pairs += [label, feature] * confusion_matrix[label, feature]
+ pairs = np.array(pairs).reshape((-1, 2))
+ return pairs[:, 0], pairs[:, 1]
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testVars(self):
+ metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2)
+ _assert_metric_variables(self, (
+ 'cohen_kappa/po:0',
+ 'cohen_kappa/pe_row:0',
+ 'cohen_kappa/pe_col:0',))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ kappa, _ = metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2,
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [kappa])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, update_op = metrics.cohen_kappa(
+ predictions_idx=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_classes=2,
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run(update_op)
+
+ # Then verify idempotency.
+ initial_kappa = kappa.eval()
+ for _ in range(10):
+ self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)
+
+ def testBasic(self):
+ confusion_matrix = np.array([
+ [9, 3, 1],
+ [4, 8, 2],
+ [2, 1, 6]])
+ # overall total = 36
+ # po = [9, 8, 6], sum(po) = 23
+ # pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
+ # finally, kappa = (sum(po) - sum(pe)) / (N - sum(pe))
+ # = (23 - 12.33) / (36 - 12.33)
+ # = 0.45
+ # see: http://psych.unl.edu/psycrs/handcomp/hckappa.PDF
+ expect = 0.45
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+
+ dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
+ shapes = [(len(labels),), # 1-dim
+ (len(labels), 1)] # 2-dim
+ weights = [None, np.ones_like(labels)]
+
+ for dtype in dtypes:
+ for shape in shapes:
+ for weight in weights:
+ with self.test_session() as sess:
+ predictions_tensor = constant_op.constant(
+ np.reshape(predictions, shape), dtype=dtype)
+ labels_tensor = constant_op.constant(
+ np.reshape(labels, shape), dtype=dtype)
+ kappa, update_op = metrics.cohen_kappa(
+ labels_tensor, predictions_tensor, 3, weights=weight)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 2)
+ self.assertAlmostEqual(expect, kappa.eval(), 2)
+
+ def testAllCorrect(self):
+ inputs = np.arange(0, 100) % 4
+ # confusion matrix
+ # [[25, 0, 0, 0],
+ # [0, 25, 0, 0],
+ # [0, 0, 25, 0],
+ # [0, 0, 0, 25]]
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
+ expect = 1.0
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(inputs)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testAllIncorrect(self):
+ labels = np.arange(0, 100) % 4
+ predictions = (labels + 1) % 4
+ # confusion matrix
+ # [[0, 25, 0, 0],
+ # [0, 0, 25, 0],
+ # [0, 0, 0, 25],
+ # [25, 0, 0, 0]]
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
+ expect = -0.333333333333
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testWeighted(self):
+ confusion_matrix = np.array([
+ [9, 3, 1],
+ [4, 8, 2],
+ [2, 1, 6]])
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+ num_samples = np.sum(confusion_matrix, dtype=np.int32)
+ weights = (np.arange(0, num_samples) % 5) / 5.0
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+ # labels, predictions, sample_weight=weights)
+ expect = 0.453466583385
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels)
+ kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
+ weights=weights)
+
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expect, sess.run(update_op), 5)
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testWithMultipleUpdates(self):
+ confusion_matrix = np.array([
+ [90, 30, 10, 20],
+ [40, 80, 20, 30],
+ [20, 10, 60, 35],
+ [15, 25, 30, 25]])
+ labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
+ num_samples = np.sum(confusion_matrix, dtype=np.int32)
+ weights = (np.arange(0, num_samples) % 5) / 5.0
+ num_classes = confusion_matrix.shape[0]
+
+ batch_size = num_samples // 10
+ predictions_t = array_ops.placeholder(dtypes_lib.float32,
+ shape=(batch_size,))
+ labels_t = array_ops.placeholder(dtypes_lib.int32,
+ shape=(batch_size,))
+ weights_t = array_ops.placeholder(dtypes_lib.float32,
+ shape=(batch_size,))
+ kappa, update_op = metrics.cohen_kappa(
+ labels_t, predictions_t, num_classes, weights=weights_t)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ for idx in range(0, num_samples, batch_size):
+ batch_start, batch_end = idx, idx + batch_size
+ sess.run(update_op,
+ feed_dict={labels_t: labels[batch_start:batch_end],
+ predictions_t: predictions[batch_start:batch_end],
+ weights_t: weights[batch_start:batch_end]})
+ # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
+ # labels_np, predictions_np, sample_weight=weights_np)
+ expect = 0.289965397924
+ self.assertAlmostEqual(expect, kappa.eval(), 5)
+
+ def testInvalidNumClasses(self):
+ predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+ labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+ with self.assertRaisesRegexp(ValueError, 'num_classes'):
+ metrics.cohen_kappa(labels, predictions, 1)
+
+ def testInvalidDimension(self):
+ predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 1))
+ invalid_labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 2))
+ with self.assertRaises(ValueError):
+ metrics.cohen_kappa(invalid_labels, predictions, 3)
+
+ invalid_predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 2))
+ labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
+ with self.assertRaises(ValueError):
+ metrics.cohen_kappa(labels, invalid_predictions, 3)
+
+
if __name__ == '__main__':
test.main()
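
The sklearn reference values cited in the test comments (e.g. in testWeighted and testWithMultipleUpdates) could be reproduced with a sketch like the following, assuming scikit-learn (~0.19) is available; the arrays below are hypothetical stand-ins for the ones the tests derive via _confusion_matrix_to_samples:

    import numpy as np
    from sklearn.metrics import cohen_kappa_score

    labels = np.array([0, 1, 2, 0, 1, 2, 2, 0])
    predictions = np.array([0, 1, 1, 0, 2, 2, 2, 1])
    weights = np.array([1.0, 0.5, 1.0, 1.0, 0.5, 1.0, 1.0, 1.0])

    # Unweighted and sample-weighted reference values.
    print(cohen_kappa_score(labels, predictions))
    print(cohen_kappa_score(labels, predictions, sample_weight=weights))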