Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r--  tensorflow/python/ops/metrics_impl.py  56
1 file changed, 36 insertions(+), 20 deletions(-)
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 92b3ff2250..2d77e26081 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -831,8 +831,8 @@ def mean_per_class_accuracy(labels,
Calculates the accuracy for each class, then takes the mean of that.
For estimation of the metric over a stream of data, the function creates an
- `update_op` operation that updates these variables and returns the
- `mean_accuracy`.
+ `update_op` operation that updates the per-class accuracies and returns
+ them.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
@@ -843,8 +843,8 @@ def mean_per_class_accuracy(labels,
shape is [batch size] and type `int32` or `int64`. The tensor will be
flattened if its rank > 1.
num_classes: The possible number of labels the prediction task can
- have. This value must be provided, since a confusion matrix of
- dimension = [num_classes, num_classes] will be allocated.
+ have. This value must be provided, since two variables with shape =
+ [num_classes] will be allocated.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `labels` dimension).
@@ -857,7 +857,7 @@ def mean_per_class_accuracy(labels,
Returns:
mean_accuracy: A `Tensor` representing the mean per class accuracy.
- update_op: An operation that increments the confusion matrix.
+ update_op: An operation that updates the per-class accuracy variables.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
@@ -872,27 +872,43 @@ def mean_per_class_accuracy(labels,
with variable_scope.variable_scope(name, 'mean_accuracy',
(predictions, labels, weights)):
+ labels = math_ops.to_int64(labels)
+
+ # Flatten the input if its rank > 1.
+ if labels.get_shape().ndims > 1:
+ labels = array_ops.reshape(labels, [-1])
+
+ if predictions.get_shape().ndims > 1:
+ predictions = array_ops.reshape(predictions, [-1])
+
# Check if shape is compatible.
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- total_cm, update_op = _streaming_confusion_matrix(
- labels, predictions, num_classes, weights=weights)
+ total = metric_variable([num_classes], dtypes.float32, name='total')
+ count = metric_variable([num_classes], dtypes.float32, name='count')
- def compute_mean_accuracy(name):
- """Compute the mean per class accuracy via the confusion matrix."""
- per_row_sum = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
- cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
- denominator = per_row_sum
+ ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)
- # If the value of the denominator is 0, set it to 1 to avoid
- # zero division.
- denominator = array_ops.where(
- math_ops.greater(denominator, 0), denominator,
- array_ops.ones_like(denominator))
- accuracies = math_ops.div(cm_diag, denominator)
- return math_ops.reduce_mean(accuracies, name=name)
+ if labels.dtype != predictions.dtype:
+ predictions = math_ops.cast(predictions, labels.dtype)
+ is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
+
+ if weights is not None:
+ if weights.get_shape().ndims > 1:
+ weights = array_ops.reshape(weights, [-1])
+ weights = math_ops.to_float(weights)
+
+ is_correct = is_correct * weights
+ ones = ones * weights
+
+ update_total_op = state_ops.scatter_add(total, labels, ones)
+ update_count_op = state_ops.scatter_add(count, labels, is_correct)
+
+ per_class_accuracy = _safe_div(count, total, None)
- mean_accuracy_v = compute_mean_accuracy('mean_accuracy')
+ mean_accuracy_v = math_ops.reduce_mean(per_class_accuracy,
+ name='mean_accuracy')
+ update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean_accuracy_v)
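
For reference, a minimal NumPy sketch (not part of the change) of the accumulation scheme the new implementation uses: per-class `total` and `count` vectors updated by scatter-adds keyed on the label, followed by a safe division and a mean over classes. Only the names `total`, `count`, `labels`, `predictions`, `weights`, and `num_classes` mirror the diff; everything else here is illustrative.

import numpy as np

def mean_per_class_accuracy_sketch(labels, predictions, num_classes, weights=None):
    """Single-batch sketch of the streaming update performed by the new code."""
    labels = np.asarray(labels).reshape(-1).astype(np.int64)
    predictions = np.asarray(predictions).reshape(-1).astype(labels.dtype)

    ones = np.ones(labels.size, dtype=np.float32)
    is_correct = (predictions == labels).astype(np.float32)

    if weights is not None:
        weights = np.asarray(weights, dtype=np.float32).reshape(-1)
        is_correct *= weights
        ones *= weights

    # Counterparts of the two scatter_add updates on the metric variables.
    total = np.zeros(num_classes, dtype=np.float32)  # weighted examples per class
    count = np.zeros(num_classes, dtype=np.float32)  # weighted correct predictions per class
    np.add.at(total, labels, ones)
    np.add.at(count, labels, is_correct)

    # Safe division: classes with no examples contribute an accuracy of 0.
    per_class_accuracy = np.where(total > 0, count / np.maximum(total, 1e-7), 0.0)
    return per_class_accuracy.mean()

print(mean_per_class_accuracy_sketch([0, 0, 1, 2], [0, 1, 1, 2], num_classes=3))
# per-class accuracies are [0.5, 1.0, 1.0], so the mean is 0.8333...

In the TensorFlow version the two accumulators are `metric_variable`s, so each run of `update_op` keeps adding to `total` and `count` across batches, while `mean_accuracy` is recomputed from the current variable values.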