diff options
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 9273659a77..10ff4be2dd 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -949,6 +949,12 @@ def mean_iou(labels, cm_diag = math_ops.to_float(array_ops.diag_part(total_cm)) denominator = sum_over_row + sum_over_col - cm_diag + # The mean is only computed over classes that appear in the + # label or prediction tensor. If the denominator is 0, we need to + # ignore the class. + num_valid_entries = math_ops.reduce_sum(math_ops.cast( + math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) + # If the value of the denominator is 0, set it to 1 to avoid # zero division. denominator = array_ops.where( @@ -956,7 +962,13 @@ def mean_iou(labels, denominator, array_ops.ones_like(denominator)) iou = math_ops.div(cm_diag, denominator) - return math_ops.reduce_mean(iou, name=name) + + # If the number of valid entries is 0 (no classes) we return 0. + result = array_ops.where( + math_ops.greater(num_valid_entries, 0), + math_ops.reduce_sum(iou, name=name) / num_valid_entries, + 0) + return result mean_iou_v = compute_mean_iou('mean_iou') |