diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-08-22 18:13:37 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-08-22 23:09:49 +0800 |
commit | c05bb4efcaf53d4cbc315ef6d12de822f2557a13 (patch) | |
tree | 72be2dc5de6040aa336ddc03cf04b9fccc19be9a /tensorflow/contrib/losses | |
parent | 56ea7fc45559f372315b2aedd0a2df15113f5f93 (diff) |
CLN: replace safe_div method by div_no_nan
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops.py | 40 |
1 file changed, 10 insertions, 30 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 651de4e2f4..29f7953c3b 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -66,32 +66,6 @@ def _scale_losses(losses, weights): return math_ops.reduce_sum(reduced_losses) -def _safe_div(numerator, denominator, name="value"): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - name: An optional name for the returned op. - - Returns: - The element-wise value of the numerator divided by the denominator. - """ - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator), - name=name) - - def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -104,7 +78,8 @@ def _safe_mean(losses, num_present): then zero is returned. 
""" total_loss = math_ops.reduce_sum(losses) - return _safe_div(total_loss, num_present) + return math_ops.div_no_nan(total_loss, num_present, + negative_to_zero=True, name="value") @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -609,11 +584,16 @@ def mean_pairwise_squared_error(predictions, math_ops.square(diffs), reduction_indices=reduction_indices) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch) + term1 = 2.0 * math_ops.div_no_nan(sum_squares_diff_per_batch, + num_present_per_batch, + negative_to_zero=True, + name="value") sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div( - math_ops.square(sum_diff), math_ops.square(num_present_per_batch)) + term2 = 2.0 * math_ops.div_no_nan(math_ops.square(sum_diff), + math_ops.square(num_present_per_batch), + negative_to_zero=True, + name="value") loss = _scale_losses(term1 - term2, weights) |