aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py')
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
1 files changed, 46 insertions, 2 deletions
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 259e62bd86..c95b9da1e4 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase):
p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+ def _test_sparsemax_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax transfers nan"""
+ z_nan = np.asarray([
+ [0, np.nan, 0],
+ [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan],
+ ]).astype(dtype)
+
+ _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_nan)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax is infinity safe"""
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg)
+
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_pos)
+
+ _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_mix)
+
+
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 1"""
z = np.zeros((1, 10))
@@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase):
self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
- def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 2"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase):
self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
- self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_nan(dtype, random, use_gpu=False)
self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_permutation(dtype, random, use_gpu=False)