diff options
Diffstat (limited to 'tensorflow/contrib/sparsemax/python/ops/sparsemax.py')
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax.py | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py index e617af2ff1..f79c93f347 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -49,7 +49,14 @@ def sparsemax(logits, name=None): obs = array_ops.shape(logits)[0] dims = array_ops.shape(logits)[1] - z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + # In the paper, they call the logits z. + # The mean(logits) can be subtracted from logits to make the algorithm + # more numerically stable. The instability in this algorithm comes mostly + # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close + # to zero. However, in practice the numerical instability issues are very + # minor and subtracting the mean causes extra issues with inf and nan + # input. + z = logits # sort z z_sorted, _ = nn.top_k(z, k=dims) @@ -64,10 +71,24 @@ def sparsemax(logits, name=None): k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1) # calculate tau(z) - indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1) + # If there are inf values or all values are -inf, the k_z will be zero, + # this is mathematically invalid and will also cause the gather_nd to fail. + # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then + # fixed later (see p_safe) by returning p = nan. This results in the same + # behavior as softmax. 
+ k_z_safe = math_ops.maximum(k_z, 1) + indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1) tau_sum = array_ops.gather_nd(z_cumsum, indices) tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype) # calculate p - return math_ops.maximum( + p = math_ops.maximum( math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis]) + # If k_z = 0 or if z = nan, then the input is invalid + p_safe = array_ops.where( + math_ops.logical_or( + math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])), + array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)), + p) + + return p_safe |