diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-24 16:13:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 16:19:42 -0700 |
commit | 6c40bc717442d56f0b6a60658b05f0549afd69ee (patch) | |
tree | 2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/contrib/losses | |
parent | d25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff) |
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility.
END_PUBLIC
Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798.
PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops.py | 37 |
1 files changed, 30 insertions, 7 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 7e5ab05987..651de4e2f4 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -66,6 +66,32 @@ def _scale_losses(losses, weights): return math_ops.reduce_sum(reduced_losses) +def _safe_div(numerator, denominator, name="value"): + """Computes a safe divide which returns 0 if the denominator is zero. + + Note that the function contains an additional conditional check that is + necessary for avoiding situations where the loss is zero causing NaNs to + creep into the gradient computation. + + Args: + numerator: An arbitrary `Tensor`. + denominator: A `Tensor` whose shape matches `numerator` and whose values are + assumed to be non-negative. + name: An optional name for the returned op. + + Returns: + The element-wise value of the numerator divided by the denominator. + """ + return array_ops.where( + math_ops.greater(denominator, 0), + math_ops.div(numerator, + array_ops.where( + math_ops.equal(denominator, 0), + array_ops.ones_like(denominator), denominator)), + array_ops.zeros_like(numerator), + name=name) + + def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -78,7 +104,7 @@ def _safe_mean(losses, num_present): then zero is returned. """ total_loss = math_ops.reduce_sum(losses) - return math_ops.div_no_nan(total_loss, num_present, name="value") + return _safe_div(total_loss, num_present) @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -583,14 +609,11 @@ def mean_pairwise_squared_error(predictions, math_ops.square(diffs), reduction_indices=reduction_indices) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch, - num_present_per_batch, - name="value") + term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch) sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch), - name="value") + term2 = 2.0 * _safe_div( + math_ops.square(sum_diff), math_ops.square(num_present_per_batch)) loss = _scale_losses(term1 - term2, weights) |