diff options
Diffstat (limited to 'tensorflow/python/ops/metrics_impl.py')
-rw-r--r-- | tensorflow/python/ops/metrics_impl.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index e04121ee31..25e1613a65 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -175,7 +175,7 @@ def _maybe_expand_labels(labels, predictions): def _safe_div(numerator, denominator, name): - """Divides two values, returning 0 if the denominator is <= 0. + """Divides two tensors element-wise, returning 0 if the denominator is <= 0. Args: numerator: A real `Tensor`. @@ -185,11 +185,11 @@ def _safe_div(numerator, denominator, name): Returns: 0 if `denominator` <= 0, else `numerator` / `denominator` """ - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.truediv(numerator, denominator), - 0, - name=name) + t = math_ops.truediv(numerator, denominator) + zero = array_ops.zeros_like(t, dtype=denominator.dtype) + condition = math_ops.greater(denominator, zero) + zero = math_ops.cast(zero, t.dtype) + return array_ops.where(condition, t, zero, name=name) def _safe_scalar_div(numerator, denominator, name): |