aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-24 16:13:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 16:19:42 -0700
commit6c40bc717442d56f0b6a60658b05f0549afd69ee (patch)
tree2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/contrib/losses
parentd25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff)
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility. END_PUBLIC Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798. PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py37
1 file changed, 30 insertions, 7 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 7e5ab05987..651de4e2f4 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -66,6 +66,32 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)
+def _safe_div(numerator, denominator, name="value"):
+ """Computes a safe divide which returns 0 if the denominator is zero.
+
+ Note that the function contains an additional conditional check that is
+ necessary for avoiding situations where the loss is zero causing NaNs to
+ creep into the gradient computation.
+
+ Args:
+ numerator: An arbitrary `Tensor`.
+ denominator: A `Tensor` whose shape matches `numerator` and whose values are
+ assumed to be non-negative.
+ name: An optional name for the returned op.
+
+ Returns:
+ The element-wise value of the numerator divided by the denominator.
+ """
+ return array_ops.where(
+ math_ops.greater(denominator, 0),
+ math_ops.div(numerator,
+ array_ops.where(
+ math_ops.equal(denominator, 0),
+ array_ops.ones_like(denominator), denominator)),
+ array_ops.zeros_like(numerator),
+ name=name)
+
+
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
@@ -78,7 +104,7 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return math_ops.div_no_nan(total_loss, num_present, name="value")
+ return _safe_div(total_loss, num_present)
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -583,14 +609,11 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
- num_present_per_batch,
- name="value")
+ term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
- term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
- math_ops.square(num_present_per_batch),
- name="value")
+ term2 = 2.0 * _safe_div(
+ math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
loss = _scale_losses(term1 - term2, weights)