aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/metrics/python/ops/confusion_matrix_ops.py
blob: 89d7a344033fa27eaed23cfbe7fcbe134ba5e9f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


"""Confusion matrix related metrics."""


def confusion_matrix(predictions, targets, num_classes=None, name=None):
  """Computes the confusion matrix from predictions and targets.

  Calculate the Confusion Matrix for a pair of prediction and
  target 1-D int arrays.

  Considering a prediction array such as: `[1, 2, 3]`
  And a target array such as: `[2, 2, 3]`

  The confusion matrix returned would be the following one:
      [[0, 0, 0, 0]
       [0, 0, 1, 0]
       [0, 0, 1, 0]
       [0, 0, 0, 1]]

  Where the matrix rows represent the prediction labels and the columns
  represent the target labels. The confusion matrix is always a 2-D array
  of shape [n, n], where n is the number of valid labels for a given
  classification task. Both prediction and target must be 1-D arrays of
  the same shape in order for this function to work.

  Args:
    predictions: A 1-D array representing the predictions for a given
                 classification.
    targets: A 1-D array representing the real labels for the classification
             task.
    num_classes: The possible number of labels the classification task can
                 have. If this value is not provided, it will be calculated
                 using both predictions and targets array (largest label
                 seen in either array, plus one).
    name: Optional scope name for the ops created by this function.

  Returns:
    An l X l `int32` matrix representing the confusion matrix, where l is the
    number of possible labels in the classification task.

  Raises:
    ValueError: If predictions and targets are not both 1-D vectors or do not
                have the same size.
  """
  with ops.op_scope([predictions, targets, num_classes], name,
                    'confusion_matrix'):
    predictions = ops.convert_to_tensor(
        predictions, name="predictions", dtype=dtypes.int64)
    targets = ops.convert_to_tensor(
        targets, name="targets", dtype=dtypes.int64)

    # Enforce the documented contract up front so that callers get a clear
    # ValueError instead of a failure deep inside pack/SparseTensor. Shapes
    # that are statically unknown pass these checks.
    predictions.get_shape().assert_has_rank(1)
    targets.get_shape().assert_has_rank(1)
    predictions.get_shape().assert_is_compatible_with(targets.get_shape())

    if num_classes is None:
      # Labels are zero-based, so the class count is the largest label + 1.
      num_classes = math_ops.maximum(math_ops.reduce_max(predictions),
                                     math_ops.reduce_max(targets)) + 1

    # SparseTensor requires an int64 shape. A Python-int `num_classes` would
    # otherwise be packed as int32 and rejected, so cast explicitly.
    shape = math_ops.to_int64(array_ops.pack([num_classes, num_classes]))

    # One (prediction, target) coordinate per sample, each contributing 1.
    indices = array_ops.transpose(
        array_ops.pack([predictions, targets]))
    values = array_ops.ones_like(predictions, dtype=dtypes.int32)
    cm_sparse = ops.SparseTensor(
        indices=indices, values=values, shape=shape)

    # Adding the sparse counts onto a dense zero matrix densifies the result.
    zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtypes.int32)

    return sparse_ops.sparse_add(zero_matrix, cm_sparse)