aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/metrics_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r--tensorflow/python/ops/metrics_impl.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 9273659a77..10ff4be2dd 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -949,6 +949,12 @@ def mean_iou(labels,
cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
denominator = sum_over_row + sum_over_col - cm_diag
+ # The mean is only computed over classes that appear in the
+ # label or prediction tensor. If the denominator is 0, we need to
+ # ignore the class.
+ num_valid_entries = math_ops.reduce_sum(math_ops.cast(
+ math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+
# If the value of the denominator is 0, set it to 1 to avoid
# zero division.
denominator = array_ops.where(
@@ -956,7 +962,13 @@ def mean_iou(labels,
denominator,
array_ops.ones_like(denominator))
iou = math_ops.div(cm_diag, denominator)
- return math_ops.reduce_mean(iou, name=name)
+
+ # If the number of valid entries is 0 (no classes) we return 0.
+ result = array_ops.where(
+ math_ops.greater(num_valid_entries, 0),
+ math_ops.reduce_sum(iou, name=name) / num_valid_entries,
+ 0)
+ return result
mean_iou_v = compute_mean_iou('mean_iou')