author    James Qin <jamesqin@google.com>  2018-03-21 15:55:30 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-21 15:58:16 -0700
commit 942a32bc71291994c14625b6311268319dd27808 (patch)
tree   1ed34c04d06867fd34ef2dcba46351fb7fe6c5bc /tensorflow/python/kernel_tests/softmax_op_test.py
parent 9cd65e9a9081640934b2b78cf84b6e51ddd69796 (diff)
Change Softmax on CUDA to use fp32 for denominator when input/output are fp16.
This avoids potential overflow in the denominator and also makes sure that accumulation is done in high precision.
PiperOrigin-RevId: 189982655
Diffstat (limited to 'tensorflow/python/kernel_tests/softmax_op_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py | 38
1 file changed, 32 insertions(+), 6 deletions(-)
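A minimal NumPy sketch (not part of this commit, and independent of the CUDA kernel) of the failure mode the change guards against: with fp16 inputs and equal logits, every exp(logit - max) term is 1.0, and an fp16 accumulator stops growing once the running sum reaches 2048, because the next representable fp16 value above 2048 is more than 1.0 away; with a tree-style reduction, large partial sums can also overflow fp16's ~65504 maximum. Accumulating the denominator in fp32, as the kernel now does, avoids both problems.

import numpy as np

n = 10000
exp_vals = np.ones(n, dtype=np.float16)   # exp(0) == 1.0 after max-subtraction

acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for v in exp_vals:
    acc16 = np.float16(acc16 + v)         # fp16 accumulator stalls at 2048.0
    acc32 += np.float32(v)                # fp32 accumulator reaches 10000.0

print(acc16, acc32)                       # 2048.0 10000.0
print(1.0 / float(acc16), 1.0 / acc32)    # ~4.9e-4 vs. the correct 1e-4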
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 2b8e99e18e..981f96b74d 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -18,14 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import unittest
import numpy as np
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
@test_util.with_c_api
@@ -41,9 +44,10 @@ class SoftmaxTest(test.TestCase):
features, axis=dim), one_only_on_dim))
softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
if log:
- return np.log(softmax)
+ res = np.log(softmax)
else:
- return softmax
+ res = softmax
+ return res
def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
# A previous version of the code checked the op name rather than the op type
@@ -53,9 +57,9 @@ class SoftmaxTest(test.TestCase):
np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
with self.test_session(use_gpu=use_gpu):
if log:
- tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name)
+ tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
else:
- tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name)
+ tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name)
out = tf_softmax.eval()
self.assertAllCloseAccordingToType(np_softmax, out)
self.assertShapeEqual(np_softmax, tf_softmax)
@@ -117,10 +121,32 @@ class SoftmaxTest(test.TestCase):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testFloatGPU(self):
+ if test.is_gpu_available(cuda_only=True):
+ rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+ cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+ for row, col in zip(rows, cols):
+ logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
+ data = np.random.rand(row, col)
+ self._testAll(data.astype(np.float32))
+
def testHalf(self):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16))
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testHalfGPU(self):
+ if test.is_gpu_available(cuda_only=True):
+ rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+ cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+ for row, col in zip(rows, cols):
+ logging.info("Testing softmax half dtype in shape [%d, %d]", row, col)
+ data = np.random.rand(row, col)
+ self._testAll(data.astype(np.float16))
+
def testDouble(self):
self._testSoftmax(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
@@ -169,7 +195,7 @@ class SoftmaxTest(test.TestCase):
self.assertEqual(0, array_ops.size(x).eval())
# reshape would raise if logits is empty
with self.assertRaises(errors_impl.InvalidArgumentError):
- nn_ops.softmax(x, dim=0).eval()
+ nn_ops.softmax(x, axis=0).eval()
def testDimTooLarge(self):
with self.test_session():
@@ -177,7 +203,7 @@ class SoftmaxTest(test.TestCase):
# inference error.
dim = array_ops.placeholder_with_default(100, shape=[])
with self.assertRaises(errors_impl.InvalidArgumentError):
- nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval()
+ nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval()
def testLargeDims(self):
# Make sure that we properly handle large inputs. See