author    TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 16:16:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-05 16:16:26 -0700
commit  12443341c1cf1c96fa187ca08dee2f2a9b9f618b (patch)
tree    a367663612e29e3d0734a4417a7fc152b3170769 /tensorflow/contrib
parent  89c887558d8b0067213c39a79d5d048d3422b6dd (diff)
parent  f410ffc1699e864e84857089183db0d952ada7fe (diff)
Merge pull request #21183 from AndreasMadsen:sparsemax-2
PiperOrigin-RevId: 215981773
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py | 40
-rw-r--r--  tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py | 48
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax.py | 27
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py | 28
4 files changed, 132 insertions(+), 11 deletions(-)
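
For reference, the sparsemax projection these files implement (Martins & Astudillo, 2016) can be sketched in a few lines of NumPy. This is our illustration of the algorithm, not code from the commit; the name sparsemax_ref is ours, and the sketch deliberately omits the nan/inf guards this patch adds:

import numpy as np

def sparsemax_ref(z):
  """Row-wise sparsemax: Euclidean projection of each row onto the simplex."""
  z = np.asarray(z, dtype=np.float64)
  z_sorted = np.sort(z, axis=1)[:, ::-1]   # sort each row descending
  z_cumsum = np.cumsum(z_sorted, axis=1)
  k = np.arange(1, z.shape[1] + 1)
  # support size from the paper: k_z = max{k : 1 + k * z_(k) > cumsum_k}
  k_z = (1 + k * z_sorted > z_cumsum).sum(axis=1)
  tau = (z_cumsum[np.arange(z.shape[0]), k_z - 1] - 1) / k_z
  return np.maximum(z - tau[:, np.newaxis], 0)
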
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index 360e7dbe75..7743f5b4a7 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -109,6 +109,42 @@ class SparsemaxLossTest(test.TestCase):
np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
+ def _test_sparsemax_loss_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax-loss transfers nan"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_nan = np.asarray([[0, np.nan, 0], [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]]).astype(dtype)
+
+ _, tf_loss_nan = self._tf_sparsemax_loss(z_nan, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan], tf_loss_nan)
+
+ def _test_sparsemax_loss_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax-loss is infinity safe"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, 0],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, 0],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, 0],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_loss_neg = self._tf_sparsemax_loss(z_neg, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([0.25, np.inf, 0, np.nan], tf_loss_neg)
+
+ _, tf_loss_pos = self._tf_sparsemax_loss(z_pos, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_pos)
+
+ _, tf_loss_mix = self._tf_sparsemax_loss(z_mix, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_mix)
+
def _test_constant_add(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 3"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -198,6 +234,10 @@ class SparsemaxLossTest(test.TestCase):
self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+ self._test_sparsemax_loss_of_nan(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_of_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
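
As a hand check of the first z_neg expectation above (our arithmetic, not part of the commit): sparsemax([0, -inf, 0]) = [0.5, 0, 0.5], so the support term is 2 * 0.5 * (0 - 0.5 * 0.5) = -0.25 and the label term for q = [0, 0, 1] is 1 * (0.5 * 1 - 0) = 0.5, giving a loss of -0.25 + 0.5 = 0.25. The -inf entry is outside the support and contributes nothing, which is exactly what the where-based sum_s in sparsemax_loss.py further down guarantees.
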
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 259e62bd86..c95b9da1e4 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase):
p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+ def _test_sparsemax_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax transfers nan"""
+ z_nan = np.asarray([
+ [0, np.nan, 0],
+ [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan],
+ ]).astype(dtype)
+
+ _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_nan)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax is infinity safe"""
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg)
+
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_pos)
+
+ _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_mix)
+
+
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 1"""
z = np.zeros((1, 10))
@@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase):
self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
- def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 2"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase):
self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
- self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_nan(dtype, random, use_gpu=False)
self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_permutation(dtype, random, use_gpu=False)
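
The z_neg expectations above can be reproduced with the sparsemax_ref sketch from the top of this page (our check, using that hypothetical helper):

print(sparsemax_ref(np.array([[0., -np.inf, 0.]])))       # [[0.5 0.  0.5]]
print(sparsemax_ref(np.array([[0., -np.inf, -np.inf]])))  # [[1. 0. 0.]]

Rows that contain +inf, and the all -inf row, make k_z = 0 in the unguarded reference; that is why the op changes below introduce k_z_safe and p_safe to return nan instead of failing.
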
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index e617af2ff1..f79c93f347 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -49,7 +49,14 @@ def sparsemax(logits, name=None):
obs = array_ops.shape(logits)[0]
dims = array_ops.shape(logits)[1]
- z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, the logits are called z.
+ # The mean of the logits can be subtracted from them to make the
+ # algorithm more numerically stable. The instability in this algorithm
+ # comes mostly from z_cumsum; subtracting the mean keeps z_cumsum close
+ # to zero. However, in practice the numerical instability is very minor,
+ # and subtracting the mean causes extra problems with inf and nan
+ # inputs.
+ z = logits
# sort z
z_sorted, _ = nn.top_k(z, k=dims)
@@ -64,10 +71,24 @@ def sparsemax(logits, name=None):
k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
# calculate tau(z)
- indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ # If there are inf values, or all values are -inf, k_z will be zero;
+ # this is mathematically invalid and would also cause the gather_nd to
+ # fail. Prevent this for now by setting k_z = 1 when k_z = 0; this is
+ # then corrected later (see p_safe) by returning p = nan, matching the
+ # behavior of softmax.
+ k_z_safe = math_ops.maximum(k_z, 1)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1)
tau_sum = array_ops.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
# calculate p
- return math_ops.maximum(
+ p = math_ops.maximum(
math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
+ # If k_z = 0, or if z contains nan, the input is invalid
+ p_safe = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])),
+ array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)),
+ p)
+
+ return p_safe
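
The guarded op logic above can be mirrored in NumPy (our sketch, extending the hypothetical sparsemax_ref from the top of this page; not code from the commit):

def sparsemax_safe(z):
  """sparsemax_ref plus the k_z_safe and p_safe guards added above."""
  z = np.asarray(z, dtype=np.float64)
  z_sorted = np.sort(z, axis=1)[:, ::-1]
  z_cumsum = np.cumsum(z_sorted, axis=1)
  k = np.arange(1, z.shape[1] + 1)
  k_z = (1 + k * z_sorted > z_cumsum).sum(axis=1)
  k_z_safe = np.maximum(k_z, 1)   # avoid the invalid gather when k_z = 0
  with np.errstate(divide="ignore", invalid="ignore"):
    tau = (z_cumsum[np.arange(z.shape[0]), k_z_safe - 1] - 1) / k_z
    p = np.maximum(z - tau[:, np.newaxis], 0)
  # rows with k_z = 0 or a nan cumsum are invalid; return nan, like softmax
  p[(k_z == 0) | np.isnan(z_cumsum[:, -1])] = np.nan
  return p
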
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 582d1e6136..c0438f16bc 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None):
sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
labels = ops.convert_to_tensor(labels, name="labels")
- shifted_logits = logits - \
- math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, the logits are called z.
+ # A constant can be subtracted from the logits to make the algorithm
+ # more numerically stable in theory. However, there is no major source
+ # of numerical instability in this algorithm.
+ z = logits
# sum over support
- support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
- sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+ # Use a conditional (where) instead of a multiplication to support
+ # z = -inf. If z = -inf and there is no support (sparsemax = 0), a
+ # multiplication would give 0 * -inf = nan, which is not correct here.
+ sum_s = array_ops.where(
+ math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
+ sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))
# - z_k + ||q||^2
- q_part = labels * (0.5 * labels - shifted_logits)
+ q_part = labels * (0.5 * labels - z)
+ # Fix the case where labels = 0 and z = -inf, where q_part would
+ # otherwise be 0 * -inf = nan. Since labels = 0, no cost for z = -inf
+ # should be considered.
+ # The code below also covers the case where z = inf. However, in that
+ # case the sparsemax will be nan, which means sum_s will also be nan,
+ # so this case doesn't need additional special treatment.
+ q_part_safe = array_ops.where(
+ math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
+ array_ops.zeros_like(z), q_part)
- return math_ops.reduce_sum(sum_s + q_part, axis=1)
+ return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
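
The loss with both where-guards reduces to the following NumPy sketch (ours, assuming the hypothetical sparsemax_safe from the previous file's sketch):

def sparsemax_loss_safe(z, q):
  """Sparsemax loss with the -inf guards introduced above (our sketch)."""
  z = np.asarray(z, dtype=np.float64)
  q = np.asarray(q, dtype=np.float64)
  p = sparsemax_safe(z)
  with np.errstate(invalid="ignore"):
    # where, not support * term: p = 0 with z = -inf would give 0 * -inf = nan
    sum_s = np.where((p > 0) | np.isnan(p), p * (z - 0.5 * p), 0.0)
    # zero the label term where labels = 0 and z is +/-inf (0 * inf = nan)
    q_part = np.where((q == 0) & np.isinf(z), 0.0, q * (0.5 * q - z))
  return np.sum(sum_s + q_part, axis=1)

print(sparsemax_loss_safe([[0., -np.inf, 0.]], [[0., 0., 1.]]))  # [0.25]
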