diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 16:16:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 16:16:26 -0700 |
commit | 12443341c1cf1c96fa187ca08dee2f2a9b9f618b (patch) | |
tree | a367663612e29e3d0734a4417a7fc152b3170769 /tensorflow | |
parent | 89c887558d8b0067213c39a79d5d048d3422b6dd (diff) | |
parent | f410ffc1699e864e84857089183db0d952ada7fe (diff) |
Merge pull request #21183 from AndreasMadsen:sparsemax-2
PiperOrigin-RevId: 215981773
Diffstat (limited to 'tensorflow')
4 files changed, 132 insertions, 11 deletions
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py index 360e7dbe75..7743f5b4a7 100644 --- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py @@ -109,6 +109,42 @@ class SparsemaxLossTest(test.TestCase): np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3) self.assertShapeEqual(np_loss, tf_loss_op) + def _test_sparsemax_loss_of_nan(self, dtype, random, use_gpu): + """check sparsemax-loss transfers nan""" + q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + z_nan = np.asarray([[0, np.nan, 0], [0, np.nan, np.nan], + [np.nan, np.nan, np.nan]]).astype(dtype) + + _, tf_loss_nan = self._tf_sparsemax_loss(z_nan, q, dtype, use_gpu) + self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan], tf_loss_nan) + + def _test_sparsemax_loss_of_inf(self, dtype, random, use_gpu): + """check sparsemax-loss is infinity safe""" + q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]]) + z_neg = np.asarray([ + [0, -np.inf, 0], + [0, -np.inf, -np.inf], + [-np.inf, -np.inf, 0], + [-np.inf, -np.inf, -np.inf], + ]).astype(dtype) + z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, + np.inf], [np.inf, np.inf, 0], + [np.inf, np.inf, np.inf]]).astype(dtype) + z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf], + [-np.inf, np.inf, 0], [-np.inf, np.inf, + -np.inf]]).astype(dtype) + + _, tf_loss_neg = self._tf_sparsemax_loss(z_neg, q, dtype, use_gpu) + self.assertAllCloseAccordingToType([0.25, np.inf, 0, np.nan], tf_loss_neg) + + _, tf_loss_pos = self._tf_sparsemax_loss(z_pos, q, dtype, use_gpu) + self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan], + tf_loss_pos) + + _, tf_loss_mix = self._tf_sparsemax_loss(z_mix, q, dtype, use_gpu) + self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan], + tf_loss_mix) + def _test_constant_add(self, dtype, random, 
use_gpu): """check sparsemax-loss proposition 3""" z = random.uniform(low=-3, high=3, size=(test_obs, 10)) @@ -198,6 +234,10 @@ class SparsemaxLossTest(test.TestCase): self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False) + self._test_sparsemax_loss_of_nan(dtype, random, use_gpu=False) + + self._test_sparsemax_loss_of_inf(dtype, random, use_gpu=False) + self._test_constant_add(dtype, random, use_gpu=False) self._test_sparsemax_loss_positive(dtype, random, use_gpu=False) diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py index 259e62bd86..c95b9da1e4 100644 --- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py @@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase): p_sparemax, tf_sparsemax_out, half_atol=5e-3) self.assertShapeEqual(p_sparemax, tf_sparsemax_op) + def _test_sparsemax_of_nan(self, dtype, random, use_gpu): + """check sparsemax transfers nan""" + z_nan = np.asarray([ + [0, np.nan, 0], + [0, np.nan, np.nan], + [np.nan, np.nan, np.nan], + ]).astype(dtype) + + _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_nan) + + def _test_sparsemax_of_inf(self, dtype, random, use_gpu): + """check sparsemax is infinity safe""" + z_neg = np.asarray([ + [0, -np.inf, 0], + [0, -np.inf, -np.inf], + [-np.inf, -np.inf, -np.inf], + ]).astype(dtype) + z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf], + [np.inf, np.inf, np.inf]]).astype(dtype) + z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf], + [-np.inf, np.inf, -np.inf]]).astype(dtype) + + _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg) + 
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_pos) + + _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_mix) + + def _test_sparsemax_of_zero(self, dtype, random, use_gpu): """check sparsemax proposition 1, part 1""" z = np.zeros((1, 10)) @@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase): self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out) self.assertShapeEqual(p_sparemax, tf_sparsemax_op) - def _test_sparsemax_of_inf(self, dtype, random, use_gpu): + def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu): """check sparsemax proposition 1, part 2""" z = random.uniform(low=-3, high=3, size=(test_obs, 10)) @@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase): self._test_sparsemax_against_numpy(dtype, random, use_gpu=False) - self._test_sparsemax_of_zero(dtype, random, use_gpu=False) + self._test_sparsemax_of_nan(dtype, random, use_gpu=False) self._test_sparsemax_of_inf(dtype, random, use_gpu=False) + self._test_sparsemax_of_zero(dtype, random, use_gpu=False) + + self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False) + self._test_constant_add(dtype, random, use_gpu=False) self._test_permutation(dtype, random, use_gpu=False) diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py index e617af2ff1..f79c93f347 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -49,7 +49,14 @@ def sparsemax(logits, name=None): obs = array_ops.shape(logits)[0] dims = array_ops.shape(logits)[1] - z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + # In the paper, they call the logits z. 
+ # The mean(logits) can be subtracted from logits to make the algorithm + # more numerically stable. The instability in this algorithm comes mostly + # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close + # to zero. However, in practice the numerical instability issues are very + # minor and subtracting the mean causes extra issues with inf and nan + # input. + z = logits # sort z z_sorted, _ = nn.top_k(z, k=dims) @@ -64,10 +71,24 @@ def sparsemax(logits, name=None): k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1) # calculate tau(z) - indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1) + # If there are inf values or all values are -inf, the k_z will be zero, + # this is mathematically invalid and will also cause the gather_nd to fail. + # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then + # fixed later (see p_safe) by returning p = nan. This results in the same + # behavior as softmax. + k_z_safe = math_ops.maximum(k_z, 1) + indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1) tau_sum = array_ops.gather_nd(z_cumsum, indices) tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype) # calculate p - return math_ops.maximum( + p = math_ops.maximum( math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis]) + # If k_z = 0 or if z = nan, then the input is invalid + p_safe = array_ops.where( + math_ops.logical_or( + math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])), + array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)), + p) + + return p_safe diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py index 582d1e6136..c0438f16bc 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None): sparsemax = 
ops.convert_to_tensor(sparsemax, name="sparsemax") labels = ops.convert_to_tensor(labels, name="labels") - shifted_logits = logits - \ - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + # In the paper, they call the logits z. + # A constant can be subtracted from logits to make the algorithm + # more numerically stable in theory. However, there are really no major + # sources of numerical instability in this algorithm. + z = logits # sum over support - support = math_ops.cast(sparsemax > 0, sparsemax.dtype) - sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax) + # Use a conditional where instead of a multiplication to support z = -inf. + # If z = -inf, and there is no support (sparsemax = 0), a multiplication + # would cause 0 * -inf = nan, which is not correct in this case. + sum_s = array_ops.where( + math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)), + sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax)) # - z_k + ||q||^2 - q_part = labels * (0.5 * labels - shifted_logits) + q_part = labels * (0.5 * labels - z) + # Fix the case where labels = 0 and z = -inf, where q_part would + # otherwise be 0 * -inf = nan. But since the labels = 0, no cost for + # z = -inf should be considered. + # The code below also covers the case where z = inf. However, in this + # case the sparsemax will be nan, which means the sum_s will also be nan, + # therefore this case doesn't need additional special treatment. + q_part_safe = array_ops.where( + math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)), + array_ops.zeros_like(z), q_part) - return math_ops.reduce_sum(sum_s + q_part, axis=1) + return math_ops.reduce_sum(sum_s + q_part_safe, axis=1) |