Diffstat (limited to 'tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py')
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py  28
1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 582d1e6136..c0438f16bc 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None):
sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
labels = ops.convert_to_tensor(labels, name="labels")
- shifted_logits = logits - \
- math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # A constant can be subtracted from the logits to make the algorithm
+ # more numerically stable in theory. However, there is no major source
+ # of numerical instability in this algorithm.
+ z = logits
# sum over support
- support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
- sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+ # Use a conditional where instead of a multiplication to support z = -inf.
+ # If z = -inf, and there is no support (sparsemax = 0), a multiplication
+ # would cause 0 * -inf = nan, which is not correct in this case.
+ sum_s = array_ops.where(
+ math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
+ sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))
# - z_k + ||q||^2
- q_part = labels * (0.5 * labels - shifted_logits)
+ q_part = labels * (0.5 * labels - z)
+ # Fix the case where labels = 0 and z = -inf, where q_part would
+ # otherwise be 0 * -inf = nan. Since labels = 0, no cost for
+ # z = -inf should be considered.
+ # The code below also covers the case where z = inf. However, in this
+ # case the sparsemax will be nan, which means the sum_s will also be nan,
+ # therefore this case doesn't need additional special treatment.
+ q_part_safe = array_ops.where(
+ math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
+ array_ops.zeros_like(z), q_part)
- return math_ops.reduce_sum(sum_s + q_part, axis=1)
+ return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
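
For reference, the nan cases that the two array_ops.where guards above handle can be reproduced outside the graph. The following is a minimal NumPy sketch (illustrative only, not part of the patch; the values are made up and the is_nan branch of the logical_or is omitted for brevity):

import numpy as np

# Illustrative values: the last logit is masked out with -inf.
z = np.array([[1.5, 0.3, -np.inf]])        # logits
sparsemax = np.array([[1.0, 0.0, 0.0]])    # no support on the -inf entry
labels = np.array([[1.0, 0.0, 0.0]])       # one-hot target

# Naive multiplication: the unsupported -inf column yields 0 * -inf = nan.
naive_sum_s = sparsemax * (z - 0.5 * sparsemax)
print(naive_sum_s)                         # [[ 1.   0.  nan]]

# Guarded versions, mirroring the two where calls in the patch.
# (NumPy still evaluates the discarded branch and may warn about the
# invalid multiply; the selected values are finite.)
sum_s = np.where(sparsemax > 0, sparsemax * (z - 0.5 * sparsemax), 0.0)
q_part = np.where((labels == 0) & np.isinf(z), 0.0,
                  labels * (0.5 * labels - z))
print(np.sum(sum_s + q_part, axis=1))      # [0.] -- finite loss

In the TensorFlow version, array_ops.where likewise selects the zero branch element-wise, so the forward loss stays finite even when some logits are -inf.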