diff options
Diffstat (limited to 'tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py')
-rw-r--r-- | tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py | 48 |
1 files changed, 46 insertions, 2 deletions
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py index 259e62bd86..c95b9da1e4 100644 --- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py @@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase): p_sparemax, tf_sparsemax_out, half_atol=5e-3) self.assertShapeEqual(p_sparemax, tf_sparsemax_op) + def _test_sparsemax_of_nan(self, dtype, random, use_gpu): + """check sparsemax transfers nan""" + z_nan = np.asarray([ + [0, np.nan, 0], + [0, np.nan, np.nan], + [np.nan, np.nan, np.nan], + ]).astype(dtype) + + _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_nan) + + def _test_sparsemax_of_inf(self, dtype, random, use_gpu): + """check sparsemax is infinity safe""" + z_neg = np.asarray([ + [0, -np.inf, 0], + [0, -np.inf, -np.inf], + [-np.inf, -np.inf, -np.inf], + ]).astype(dtype) + z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf], + [np.inf, np.inf, np.inf]]).astype(dtype) + z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf], + [-np.inf, np.inf, -np.inf]]).astype(dtype) + + _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg) + + _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_pos) + + _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu) + self.assertAllCloseAccordingToType( + [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan]], tf_sparsemax_mix) + + def _test_sparsemax_of_zero(self, dtype, random, use_gpu): """check sparsemax proposition 1, part 1""" z = np.zeros((1, 10)) @@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase): self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out) self.assertShapeEqual(p_sparemax, tf_sparsemax_op) - def _test_sparsemax_of_inf(self, dtype, random, use_gpu): + def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu): """check sparsemax proposition 1, part 2""" z = random.uniform(low=-3, high=3, size=(test_obs, 10)) @@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase): self._test_sparsemax_against_numpy(dtype, random, use_gpu=False) - self._test_sparsemax_of_zero(dtype, random, use_gpu=False) + self._test_sparsemax_of_nan(dtype, random, use_gpu=False) self._test_sparsemax_of_inf(dtype, random, use_gpu=False) + self._test_sparsemax_of_zero(dtype, random, use_gpu=False) + + self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False) + self._test_constant_add(dtype, random, use_gpu=False) self._test_permutation(dtype, random, use_gpu=False) |