aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-22 18:13:37 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-22 23:09:49 +0800
commitc05bb4efcaf53d4cbc315ef6d12de822f2557a13 (patch)
tree72be2dc5de6040aa336ddc03cf04b9fccc19be9a /tensorflow/contrib/losses
parent56ea7fc45559f372315b2aedd0a2df15113f5f93 (diff)
CLN: replace safe_div method by div_no_nan
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py40
1 files changed, 10 insertions, 30 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 651de4e2f4..29f7953c3b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -66,32 +66,6 @@ def _scale_losses(losses, weights):
return math_ops.reduce_sum(reduced_losses)
-def _safe_div(numerator, denominator, name="value"):
- """Computes a safe divide which returns 0 if the denominator is zero.
-
- Note that the function contains an additional conditional check that is
- necessary for avoiding situations where the loss is zero causing NaNs to
- creep into the gradient computation.
-
- Args:
- numerator: An arbitrary `Tensor`.
- denominator: A `Tensor` whose shape matches `numerator` and whose values are
- assumed to be non-negative.
- name: An optional name for the returned op.
-
- Returns:
- The element-wise value of the numerator divided by the denominator.
- """
- return array_ops.where(
- math_ops.greater(denominator, 0),
- math_ops.div(numerator,
- array_ops.where(
- math_ops.equal(denominator, 0),
- array_ops.ones_like(denominator), denominator)),
- array_ops.zeros_like(numerator),
- name=name)
-
-
def _safe_mean(losses, num_present):
"""Computes a safe mean of the losses.
@@ -104,7 +78,8 @@ def _safe_mean(losses, num_present):
then zero is returned.
"""
total_loss = math_ops.reduce_sum(losses)
- return _safe_div(total_loss, num_present)
+ return math_ops.div_no_nan(total_loss, num_present,
+ negative_to_zero=True, name="value")
@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@@ -609,11 +584,16 @@ def mean_pairwise_squared_error(predictions,
math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
+ term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch,
+ num_present_per_batch,
+ negative_to_zero=True,
+ name="value")
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
- term2 = 2.0 * _safe_div(
- math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
+ term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff),
+ math_ops.square(num_present_per_batch),
+ negative_to_zero=True,
+ name="value")
loss = _scale_losses(term1 - term2, weights)