| author | James Qin <jamesqin@google.com> | 2018-03-21 15:55:30 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-21 15:58:16 -0700 |
| commit | 942a32bc71291994c14625b6311268319dd27808 | |
| tree | 1ed34c04d06867fd34ef2dcba46351fb7fe6c5bc (tensorflow/python/kernel_tests/softmax_op_test.py) | |
| parent | 9cd65e9a9081640934b2b78cf84b6e51ddd69796 | |
Change Softmax on CUDA to use fp32 for denominator when input/output are fp16.
This avoids potential overflow in the denominator and also ensures that the
accumulation is done in high precision.
PiperOrigin-RevId: 189982655
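The failure mode this change guards against is easy to reproduce outside TensorFlow. The NumPy sketch below is illustrative only (it is not the CUDA kernel change itself): it compares a softmax denominator accumulated in fp16 against one accumulated in fp32, using a hypothetical row of 70,000 identical logits, a size chosen so that the true sum exceeds float16's maximum finite value of 65,504.

```python
import numpy as np

# Hypothetical worst case: 70,000 identical logits. After the max is
# subtracted, every exp() term is exactly 1.0, so the true denominator is
# 70,000, just above float16's maximum finite value of 65,504.
logits = np.zeros(70000, dtype=np.float16)
e = np.exp(logits - logits.max())        # element-wise exp, still float16

denom_fp16 = e.sum(dtype=np.float16)     # accumulated in fp16: overflows
denom_fp32 = e.astype(np.float32).sum()  # accumulated in fp32: exact here

print(denom_fp16)                        # inf
print(denom_fp32)                        # 70000.0
print((e / denom_fp16)[:3])              # [0. 0. 0.], probabilities lost
print((e.astype(np.float32) / denom_fp32).astype(np.float16)[:3])
# ~[1.43e-05 1.43e-05 1.43e-05], the correct uniform softmax
```

The GPU tests added in the diff below exercise this path by feeding fp16 (and fp32) inputs of randomized shapes through the op and comparing against the NumPy reference.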
Diffstat (limited to 'tensorflow/python/kernel_tests/softmax_op_test.py')
| -rw-r--r-- | tensorflow/python/kernel_tests/softmax_op_test.py | 38 |

1 file changed, 32 insertions, 6 deletions
```diff
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 2b8e99e18e..981f96b74d 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -18,14 +18,17 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import unittest
 import numpy as np
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 
 
 @test_util.with_c_api
@@ -41,9 +44,10 @@ class SoftmaxTest(test.TestCase):
           features, axis=dim), one_only_on_dim))
     softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
     if log:
-      return np.log(softmax)
+      res = np.log(softmax)
     else:
-      return softmax
+      res = softmax
+    return res
 
   def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
     # A previous version of the code checked the op name rather than the op type
@@ -53,9 +57,9 @@ class SoftmaxTest(test.TestCase):
     np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
     with self.test_session(use_gpu=use_gpu):
       if log:
-        tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name)
+        tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
       else:
-        tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name)
+        tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name)
       out = tf_softmax.eval()
     self.assertAllCloseAccordingToType(np_softmax, out)
     self.assertShapeEqual(np_softmax, tf_softmax)
@@ -117,10 +121,32 @@ class SoftmaxTest(test.TestCase):
     self._testAll(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))
 
+  @unittest.skipUnless(test.is_built_with_cuda(),
+                       "Test only applicable when running on GPUs")
+  def testFloatGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+      cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float32))
+
   def testHalf(self):
     self._testAll(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16))
 
+  @unittest.skipUnless(test.is_built_with_cuda(),
+                       "Test only applicable when running on GPUs")
+  def testHalfGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+      cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax half dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float16))
+
   def testDouble(self):
     self._testSoftmax(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
@@ -169,7 +195,7 @@ class SoftmaxTest(test.TestCase):
       self.assertEqual(0, array_ops.size(x).eval())
       # reshape would raise if logits is empty
       with self.assertRaises(errors_impl.InvalidArgumentError):
-        nn_ops.softmax(x, dim=0).eval()
+        nn_ops.softmax(x, axis=0).eval()
 
   def testDimTooLarge(self):
     with self.test_session():
@@ -176,8 +202,8 @@ class SoftmaxTest(test.TestCase):
       # Use placeholder to make sure we get runtime error instead of shape
       # inference error.
       dim = array_ops.placeholder_with_default(100, shape=[])
       with self.assertRaises(errors_impl.InvalidArgumentError):
-        nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval()
+        nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval()
 
   def testLargeDims(self):
     # Make sure that we properly handle large inputs. See
```