path: root/tensorflow/contrib/losses
author    TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 10:45:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 10:45:39 -0700
commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2 (patch)
tree 3662951953b290162dc430e61ca12d3af38cc3d5 /tensorflow/contrib/losses
parent bca361df0d02fdf2911dcb2899b0257a1d92f080 (diff)
parent e3c334e57fba9afc0b0a3aa5f7787ee35e17ddf6 (diff)
Merge pull request #21798 from facaiy:ENH/div_no_nan_treate_negative_as_zero
PiperOrigin-RevId: 214290400
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- tensorflow/contrib/losses/python/losses/loss_ops.py | 37
1 file changed, 7 insertions(+), 30 deletions(-)
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 651de4e2f4..7e5ab05987 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -66,32 +66,6 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)
-def _safe_div(numerator, denominator, name="value"):
- """Computes a safe divide which returns 0 if the denominator is zero.
-
- Note that the function contains an additional conditional check that is
- necessary for avoiding situations where the loss is zero causing NaNs to
- creep into the gradient computation.
-
- Args:
- numerator: An arbitrary `Tensor`.
- denominator: A `Tensor` whose shape matches `numerator` and whose values are
- assumed to be non-negative.
- name: An optional name for the returned op.
-
- Returns:
- The element-wise value of the numerator divided by the denominator.
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.div(numerator,
- array_ops.where(
- math_ops.equal(denominator, 0),
- array_ops.ones_like(denominator), denominator)),
- array_ops.zeros_like(numerator),
- name=name)
-
-
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
@@ -104,7 +78,7 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return _safe_div(total_loss, num_present)
+ return math_ops.div_no_nan(total_loss, num_present, name="value")
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -609,11 +583,14 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
+ term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
+ num_present_per_batch,
+ name="value")
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
- term2 = 2.0 * _safe_div(
- math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
+ term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
+ math_ops.square(num_present_per_batch),
+ name="value")
loss = _scale_losses(term1 - term2, weights)
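
The removed _safe_div helper returned 0 wherever the denominator was not positive, so that an all-zero weight (num_present == 0) could not inject NaNs into the loss or its gradient; math_ops.div_no_nan gives the same result for the non-negative denominators these loss ops pass in. Below is a minimal sketch of that behaviour, written against the public tf.math.div_no_nan alias in eager (TF 2.x) style; the tensor values are illustrative only and are not part of this change:

import tensorflow as tf

# Denominators of 0 would normally produce Inf/NaN; div_no_nan yields 0 there
# instead, which is the behaviour _safe_mean and mean_pairwise_squared_error
# rely on when no elements are present.
numerator = tf.constant([3.0, -2.0, 5.0])
denominator = tf.constant([1.5, 0.0, 0.0])

result = tf.math.div_no_nan(numerator, denominator)
print(result.numpy())  # [2. 0. 0.]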