aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/metrics_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r--tensorflow/python/ops/metrics_impl.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index e04121ee31..25e1613a65 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -175,7 +175,7 @@ def _maybe_expand_labels(labels, predictions):
def _safe_div(numerator, denominator, name):
- """Divides two values, returning 0 if the denominator is <= 0.
+ """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
Args:
numerator: A real `Tensor`.
@@ -185,11 +185,11 @@ def _safe_div(numerator, denominator, name):
Returns:
0 if `denominator` <= 0, else `numerator` / `denominator`
"""
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.truediv(numerator, denominator),
- 0,
- name=name)
+ t = math_ops.truediv(numerator, denominator)
+ zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+ condition = math_ops.greater(denominator, zero)
+ zero = math_ops.cast(zero, t.dtype)
+ return array_ops.where(condition, t, zero, name=name)
def _safe_scalar_div(numerator, denominator, name):