diff options
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 56 |
1 files changed, 36 insertions, 20 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 92b3ff2250..2d77e26081 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -831,8 +831,8 @@ def mean_per_class_accuracy(labels, Calculates the accuracy for each class, then takes the mean of that. For estimation of the metric over a stream of data, the function creates an - `update_op` operation that updates these variables and returns the - `mean_accuracy`. + `update_op` operation that updates the accuracy of each class and returns + them. If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. @@ -843,8 +843,8 @@ def mean_per_class_accuracy(labels, shape is [batch size] and type `int32` or `int64`. The tensor will be flattened if its rank > 1. num_classes: The possible number of labels the prediction task can - have. This value must be provided, since a confusion matrix of - dimension = [num_classes, num_classes] will be allocated. + have. This value must be provided, since two variables with shape = + [num_classes] will be allocated. weights: Optional `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must be either `1`, or the same as the corresponding `labels` dimension). @@ -857,7 +857,7 @@ def mean_per_class_accuracy(labels, Returns: mean_accuracy: A `Tensor` representing the mean per class accuracy. - update_op: An operation that increments the confusion matrix. + update_op: An operation that updates the accuracy tensor. Raises: ValueError: If `predictions` and `labels` have mismatched shapes, or if @@ -872,27 +872,43 @@ def mean_per_class_accuracy(labels, with variable_scope.variable_scope(name, 'mean_accuracy', (predictions, labels, weights)): + labels = math_ops.to_int64(labels) + + # Flatten the input if its rank > 1. + if labels.get_shape().ndims > 1: + labels = array_ops.reshape(labels, [-1]) + + if predictions.get_shape().ndims > 1: + predictions = array_ops.reshape(predictions, [-1]) + # Check if shape is compatible. predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - total_cm, update_op = _streaming_confusion_matrix( - labels, predictions, num_classes, weights=weights) + total = metric_variable([num_classes], dtypes.float32, name='total') + count = metric_variable([num_classes], dtypes.float32, name='count') - def compute_mean_accuracy(name): - """Compute the mean per class accuracy via the confusion matrix.""" - per_row_sum = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) - cm_diag = math_ops.to_float(array_ops.diag_part(total_cm)) - denominator = per_row_sum + ones = array_ops.ones([array_ops.size(labels)], dtypes.float32) - # If the value of the denominator is 0, set it to 1 to avoid - # zero division. - denominator = array_ops.where( - math_ops.greater(denominator, 0), denominator, - array_ops.ones_like(denominator)) - accuracies = math_ops.div(cm_diag, denominator) - return math_ops.reduce_mean(accuracies, name=name) + if labels.dtype != predictions.dtype: + predictions = math_ops.cast(predictions, labels.dtype) + is_correct = math_ops.to_float(math_ops.equal(predictions, labels)) + + if weights is not None: + if weights.get_shape().ndims > 1: + weights = array_ops.reshape(weights, [-1]) + weights = math_ops.to_float(weights) + + is_correct = is_correct * weights + ones = ones * weights + + update_total_op = state_ops.scatter_add(total, labels, ones) + update_count_op = state_ops.scatter_add(count, labels, is_correct) + + per_class_accuracy = _safe_div(count, total, None) - mean_accuracy_v = compute_mean_accuracy('mean_accuracy') + mean_accuracy_v = math_ops.reduce_mean(per_class_accuracy, + name='mean_accuracy') + update_op = _safe_div(update_count_op, update_total_op, name='update_op') if metrics_collections: ops.add_to_collections(metrics_collections, mean_accuracy_v) |