diff options
Diffstat (limited to 'tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py')
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py index 582d1e6136..c0438f16bc 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None): sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax") labels = ops.convert_to_tensor(labels, name="labels") - shifted_logits = logits - \ - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + # In the paper, they call the logits z. + # A constant can be substracted from logits to make the algorithm + # more numerically stable in theory. However, there are really no major + # source numerical instability in this algorithm. + z = logits # sum over support - support = math_ops.cast(sparsemax > 0, sparsemax.dtype) - sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax) + # Use a conditional where instead of a multiplication to support z = -inf. + # If z = -inf, and there is no support (sparsemax = 0), a multiplication + # would cause 0 * -inf = nan, which is not correct in this case. + sum_s = array_ops.where( + math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)), + sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax)) # - z_k + ||q||^2 - q_part = labels * (0.5 * labels - shifted_logits) + q_part = labels * (0.5 * labels - z) + # Fix the case where labels = 0 and z = -inf, where q_part would + # otherwise be 0 * -inf = nan. But since the lables = 0, no cost for + # z = -inf should be consideredself. + # The code below also coveres the case where z = inf. Howeverm in this + # caose the sparsemax will be nan, which means the sum_s will also be nan, + # therefor this case doesn't need addtional special treatment. + q_part_safe = array_ops.where( + math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)), + array_ops.zeros_like(z), q_part) - return math_ops.reduce_sum(sum_s + q_part, axis=1) + return math_ops.reduce_sum(sum_s + q_part_safe, axis=1) |