Diffstat (limited to 'tensorflow/contrib/sparsemax/python/ops/sparsemax.py')
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax.py  27
1 file changed, 24 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index e617af2ff1..f79c93f347 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -49,7 +49,14 @@ def sparsemax(logits, name=None):
obs = array_ops.shape(logits)[0]
dims = array_ops.shape(logits)[1]
- z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # The mean(logits) can be subtracted from the logits to make the algorithm
+ # more numerically stable. The instability in this algorithm comes mostly
+ # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
+ # to zero. However, in practice the numerical instability issues are very
+ # minor and subtracting the mean causes extra issues with inf and nan
+ # input.
+ z = logits
# sort z
z_sorted, _ = nn.top_k(z, k=dims)
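The trade-off described in the new comment can be illustrated in plain NumPy (a minimal sketch with made-up numbers, not part of the module):

    import numpy as np

    logits = np.array([[101.0, 103.0, 102.0]])
    shifted = logits - logits.mean(axis=1, keepdims=True)  # the shift removed by this patch

    # Shifting keeps the sorted cumulative sum (z_cumsum) close to zero ...
    print(np.cumsum(np.sort(shifted, axis=1)[:, ::-1], axis=1))  # [[1. 1. 0.]]
    print(np.cumsum(np.sort(logits, axis=1)[:, ::-1], axis=1))   # [[103. 205. 306.]]

    # ... but an inf entry makes the mean inf: the inf logit becomes nan and
    # every finite logit collapses to -inf, which is the kind of extra issue
    # with inf and nan input that the comment refers to.
    bad = np.array([[1.0, np.inf, 2.0]])
    with np.errstate(invalid="ignore"):
        print(bad - bad.mean(axis=1, keepdims=True))  # [[-inf nan -inf]]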
@@ -64,10 +71,24 @@ def sparsemax(logits, name=None):
k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
# calculate tau(z)
- indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ # If there are inf values or all values are -inf, k_z will be zero. This
+ # is mathematically invalid and will also cause the gather_nd to fail.
+ # Prevent this issue for now by setting k_z = 1 if k_z = 0; this is then
+ # fixed later (see p_safe) by returning p = nan, which gives the same
+ # behavior as softmax.
+ k_z_safe = math_ops.maximum(k_z, 1)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1)
tau_sum = array_ops.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
# calculate p
- return math_ops.maximum(
+ p = math_ops.maximum(
math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
+ # If k_z = 0 or if z contains nan, then the input is invalid
+ p_safe = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])),
+ array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)),
+ p)
+
+ return p_safe
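Putting both hunks together, the guarded forward pass can be sketched in plain NumPy as follows. This is an illustrative sketch only; sparsemax_np is a hypothetical helper that mirrors, but does not reproduce, the graph ops above (for example, it divides by k_z_safe rather than k_z to avoid a divide-by-zero warning, which is harmless here because those rows are overwritten with nan anyway).

    import numpy as np

    def sparsemax_np(logits):
        """Hypothetical NumPy sketch of the guarded sparsemax forward pass."""
        obs, dims = logits.shape
        z = logits
        z_sorted = np.sort(z, axis=1)[:, ::-1]    # descending order, like nn.top_k
        z_cumsum = np.cumsum(z_sorted, axis=1)
        k = np.arange(1, dims + 1)
        z_check = 1 + k * z_sorted > z_cumsum     # support test from the paper
        k_z = z_check.sum(axis=1)
        # Guard: k_z == 0 (e.g. an all -inf row) would index position -1 below,
        # so clamp the index to 1 and repair the affected rows afterwards.
        k_z_safe = np.maximum(k_z, 1)
        tau_sum = z_cumsum[np.arange(obs), k_z_safe - 1]
        tau_z = (tau_sum - 1) / k_z_safe
        with np.errstate(invalid="ignore"):       # inf/nan input is expected here
            p = np.maximum(0, z - tau_z[:, np.newaxis])
        # Rows with an empty support set or a nan cumulative sum are invalid;
        # return all-nan for them, matching the behavior of softmax on such input.
        invalid = (k_z == 0) | np.isnan(z_cumsum[:, -1])
        p[invalid] = np.nan
        return p

    print(sparsemax_np(np.array([[3.0, 1.0, 0.5],
                                 [-np.inf, -np.inf, -np.inf]])))
    # first row -> [1. 0. 0.]; second row -> [nan nan nan]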