diff options
author | 2018-09-24 10:45:39 -0700 | |
---|---|---|
committer | 2018-09-24 10:45:39 -0700 | |
commit | 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2 (patch) | |
tree | 3662951953b290162dc430e61ca12d3af38cc3d5 /tensorflow/python/keras/metrics.py | |
parent | bca361df0d02fdf2911dcb2899b0257a1d92f080 (diff) | |
parent | e3c334e57fba9afc0b0a3aa5f7787ee35e17ddf6 (diff) |
Merge pull request #21798 from facaiy:ENH/div_no_nan_treate_negative_as_zero
PiperOrigin-RevId: 214290400
Diffstat (limited to 'tensorflow/python/keras/metrics.py')
-rw-r--r-- | tensorflow/python/keras/metrics.py | 19 |
1 files changed, 1 insertions, 18 deletions
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index e64241e5cf..3df425fd6e 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -155,23 +155,6 @@ def weakmethod(method): return inner -def safe_div(numerator, denominator): - """Divides two tensors element-wise, returning 0 if the denominator is <= 0. - - Args: - numerator: A `Tensor`. - denominator: A `Tensor`, with dtype matching `numerator`. - - Returns: - 0 if `denominator` <= 0, else `numerator` / `denominator` - """ - t = math_ops.truediv(numerator, denominator) - zero = array_ops.zeros_like(t, dtype=denominator.dtype) - condition = math_ops.greater(denominator, zero) - zero = math_ops.cast(zero, t.dtype) - return array_ops.where(condition, t, zero) - - def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): """Squeeze or expand last dimension if needed. @@ -503,7 +486,7 @@ class Mean(Metric): return control_flow_ops.group(update_total_op, update_count_op) def result(self): - return safe_div(self.total, self.count) + return math_ops.div_no_nan(self.total, self.count) class MeanMetricWrapper(Mean): |