diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 10:45:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 10:45:39 -0700 |
commit | 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2 (patch) | |
tree | 3662951953b290162dc430e61ca12d3af38cc3d5 /tensorflow/contrib/losses | |
parent | bca361df0d02fdf2911dcb2899b0257a1d92f080 (diff) | |
parent | e3c334e57fba9afc0b0a3aa5f7787ee35e17ddf6 (diff) |
Merge pull request #21798 from facaiy:ENH/div_no_nan_treate_negative_as_zero
PiperOrigin-RevId: 214290400
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops.py | 37 |
1 files changed, 7 insertions, 30 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 651de4e2f4..7e5ab05987 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -66,32 +66,6 @@ def _scale_losses(losses, weights): return math_ops.reduce_sum(reduced_losses) -def _safe_div(numerator, denominator, name="value"): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - name: An optional name for the returned op. - - Returns: - The element-wise value of the numerator divided by the denominator. - """ - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator), - name=name) - - def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -104,7 +78,7 @@ def _safe_mean(losses, num_present): then zero is returned. """ total_loss = math_ops.reduce_sum(losses) - return _safe_div(total_loss, num_present) + return math_ops.div_no_nan(total_loss, num_present, name="value") @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -609,11 +583,14 @@ def mean_pairwise_squared_error(predictions, math_ops.square(diffs), reduction_indices=reduction_indices) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch) + term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch, + num_present_per_batch, + name="value") sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div( - math_ops.square(sum_diff), math_ops.square(num_present_per_batch)) + term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff), + math_ops.square(num_present_per_batch), + name="value") loss = _scale_losses(term1 - term2, weights) |