-rw-r--r--  tensorflow/core/kernels/softmax_op_gpu.cu.cc        90
-rw-r--r--  tensorflow/python/framework/test_util.py            75
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                  2
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py    38
4 files changed, 145 insertions, 60 deletions
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
index 1f4a82a733..130d693dbd 100644
--- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -33,8 +33,42 @@ namespace tensorflow {
namespace {
+template <typename U, typename T>
+__device__ __host__ EIGEN_STRONG_INLINE
+ typename std::enable_if<!std::is_same<T, U>::value, U>::type
+ strict_cast(T t);
+
+template <typename U, typename T>
+__device__ __host__ EIGEN_STRONG_INLINE
+ typename std::enable_if<std::is_same<T, U>::value, U>::type
+ strict_cast(T t) {
+ return t;
+}
+
+template <>
+__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
+ Eigen::half t) {
+ return functor::HalfToFloat()(t);
+}
+
+template <>
+__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
+strict_cast<Eigen::half, float>(float t) {
+ return functor::FloatToHalf()(t);
+}
+
template <typename T>
-__global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
+struct softmax_traits {
+ using accumulator_type = T;
+};
+
+template <>
+struct softmax_traits<Eigen::half> {
+ using accumulator_type = float;
+};
+
+template <typename T, typename U>
+__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
const T* max_logits, T* output,
const int num_rows, const int num_cols,
const bool in_log_space) {
@@ -43,25 +77,33 @@ __global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
const int row = tid / num_cols;
const int col = tid % num_cols;
+ // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
+ U input = strict_cast<U>(logits[tid]);
+ U max_val = strict_cast<U>(ldg(max_logits + row));
+ U result;
+
if (row < num_rows && col < num_cols) {
- if (in_log_space)
- output[tid] =
- logits[tid] - ldg(max_logits + row) - log(ldg(sum_probs + row));
- else
- output[tid] =
- exp(logits[tid] - ldg(max_logits + row)) / ldg(sum_probs + row);
+ if (in_log_space) {
+ result = input - max_val - log(ldg(sum_probs + row));
+ } else {
+ result = exp(input - max_val) / ldg(sum_probs + row);
+ }
+ output[tid] = strict_cast<T>(result);
}
}
-template <typename T>
+template <typename T, typename U>
struct SubtractAndExpFunctor {
__host__ __device__ SubtractAndExpFunctor(const T* logits,
const T* max_logits,
const int num_cols)
: logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}
- __host__ __device__ T operator()(const int gid) const {
- return exp(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
+ __host__ __device__ U operator()(const int gid) const {
+ // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
+ const U diff =
+ strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
+ return exp(diff);
}
const T* logits_;
@@ -80,7 +122,6 @@ void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
-
} // namespace
template <typename T>
@@ -108,8 +149,10 @@ class SoftmaxOpGPU : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
softmax_out->shape(), &max_logits));
+
+ typedef typename softmax_traits<T>::accumulator_type acc_type;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
+ context->allocate_temp(DataTypeToEnum<acc_type>::value,
softmax_out->shape(), &sum_probs));
DoRowReduction<T, cub::Max, const T*>(
@@ -120,25 +163,28 @@ class SoftmaxOpGPU : public OpKernel {
const int numBlocks = Eigen::divup(rows * cols, numThreads);
cub::CountingInputIterator<int> counting_iterator(0);
- typedef cub::TransformInputIterator<T, SubtractAndExpFunctor<T>,
+ typedef cub::TransformInputIterator<acc_type,
+ SubtractAndExpFunctor<T, acc_type>,
cub::CountingInputIterator<int>>
InputIterType;
InputIterType input_itr(
counting_iterator,
- SubtractAndExpFunctor<T>(
+ SubtractAndExpFunctor<T, acc_type>(
reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols));
- DoRowReduction<T, cub::Sum, InputIterType>(
- context, const_cast<T*>(sum_probs.flat<T>().data()), input_itr, rows,
- cols);
+ DoRowReduction<acc_type, cub::Sum, InputIterType>(
+ context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
+ input_itr, rows, cols);
- GenerateNormalizedProb<<<numBlocks, numThreads, 0, cu_stream>>>(
- reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
- reinterpret_cast<const T*>(sum_probs.flat<T>().data()),
- reinterpret_cast<const T*>(max_logits.flat<T>().data()),
- const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_);
+ GenerateNormalizedProb<T, acc_type>
+ <<<numBlocks, numThreads, 0, cu_stream>>>(
+ reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
+ reinterpret_cast<const acc_type*>(
+ sum_probs.flat<acc_type>().data()),
+ reinterpret_cast<const T*>(max_logits.flat<T>().data()),
+ const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_);
}
}
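
The kernel change above comes down to one numerical idea: for Eigen::half inputs, softmax_traits selects float as the accumulator_type, strict_cast upcasts each logit before the subtract/exp, the row sum (DoRowReduction with cub::Sum) runs entirely in float, and only the final normalized value is cast back to half. A minimal numpy sketch of that strategy follows; the helper name is hypothetical and this illustrates the numerics, not the kernel itself:

    import numpy as np

    def softmax_fp16_with_fp32_accum(logits, log_space=False):
        """Hypothetical sketch: half-precision softmax with fp32 accumulation."""
        # strict_cast<float>(half): upcast the half inputs before any arithmetic.
        x = logits.astype(np.float32)
        x = x - np.max(x, axis=-1, keepdims=True)  # subtract the row max
        # The row reduction happens in float32 (the accumulator_type).
        sum_probs = np.sum(np.exp(x), axis=-1, keepdims=True)
        if log_space:
            result = x - np.log(sum_probs)
        else:
            result = np.exp(x) / sum_probs
        # strict_cast<half>(float): downcast only the final normalized result.
        return result.astype(np.float16)

Accumulating the exp sum in float32 avoids the severe rounding a float16 sum suffers on wide rows, which is why sum_probs is now allocated with DataTypeToEnum<acc_type>::value rather than the input dtype.
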
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index d8f8569939..43106b6e59 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -53,6 +53,7 @@ from tensorflow.python.eager import tape # pylint: disable=unused-import
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@@ -201,6 +202,7 @@ def _strip_checkpoint_v2_randomized(graph_def):
def IsGoogleCudaEnabled():
return pywrap_tensorflow.IsGoogleCudaEnabled()
+
def CudaSupportsHalfMatMulAndConv():
return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
@@ -335,6 +337,8 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
# Make sure default graph reflects prev_value in case next test doesn't call
# reset_default_graph().
ops.reset_default_graph()
+
+
# pylint: disable=protected-access
@@ -451,7 +455,8 @@ def with_c_api(cls):
# If the C API is already enabled, don't do anything. Some tests break if the
# same test is run twice, so this allows us to turn on the C API by default
# without breaking these tests.
- if ops._USE_C_API: return cls
+ if ops._USE_C_API:
+ return cls
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
@@ -469,6 +474,7 @@ def assert_no_new_pyobjects_executing_eagerly(f):
Useful for checking that there are no missing Py_DECREFs in the C exercised by
a bit of Python.
"""
+
def decorator(self, **kwargs):
"""Warms up, gets an object count, runs the test, checks for new objects."""
with context.eager_mode():
@@ -483,8 +489,10 @@ def assert_no_new_pyobjects_executing_eagerly(f):
new_count = len(gc.get_objects())
self.assertEqual(previous_count, new_count)
gc.enable()
+
return decorator
+
def assert_no_new_tensors(f):
"""Decorator for asserting that no new Tensors persist after a test.
@@ -508,17 +516,15 @@ def assert_no_new_tensors(f):
def _is_tensorflow_object(obj):
try:
- return isinstance(obj, (
- ops.Tensor,
- variables.Variable,
- tensor_shape.Dimension,
- tensor_shape.TensorShape))
+ return isinstance(obj,
+ (ops.Tensor, variables.Variable,
+ tensor_shape.Dimension, tensor_shape.TensorShape))
except ReferenceError:
# If the object no longer exists, we don't care about it.
return False
- tensors_before = set(id(obj) for obj in gc.get_objects()
- if _is_tensorflow_object(obj))
+ tensors_before = set(
+ id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
outside_graph_key = ops.get_default_graph()._graph_key
with ops.Graph().as_default():
# Run the test in a new graph so that collections get cleared when it's
@@ -572,18 +578,18 @@ def assert_no_garbage_created(f):
"likely due to a reference cycle. New objects in cycle(s):")
for i, obj in enumerate(gc.garbage[previous_garbage:]):
try:
- logging.error(
- "Object %d of %d" % (i, len(gc.garbage) - previous_garbage))
+ logging.error("Object %d of %d", i,
+ len(gc.garbage) - previous_garbage)
+
def _safe_object_str(obj):
return "<%s %d>" % (obj.__class__.__name__, id(obj))
- logging.error(" Object type: %s" % (_safe_object_str(obj),))
- logging.error(" Referrer types: %s" % (
- ', '.join([_safe_object_str(ref)
- for ref in gc.get_referrers(obj)]),))
- logging.error(" Referent types: %s" % (
- ', '.join([_safe_object_str(ref)
- for ref in gc.get_referents(obj)]),))
- logging.error(" Object attribute names: %s" % (dir(obj),))
+
+ logging.error(" Object type: %s", _safe_object_str(obj))
+ logging.error(" Referrer types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
+ logging.error(" Referent types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
+ logging.error(" Object attribute names: %s", dir(obj))
logging.error(" Object __str__:")
logging.error(obj)
logging.error(" Object __repr__:")
@@ -705,15 +711,23 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
return 0, 0
return int(match.group(1)), int(match.group(2))
- for local_device in device_lib.list_local_devices():
- if local_device.device_type == "GPU":
- if (min_cuda_compute_capability is None or
- compute_capability_from_device_desc(local_device.physical_device_desc)
- >= min_cuda_compute_capability):
+ try:
+ for local_device in device_lib.list_local_devices():
+ if local_device.device_type == "GPU":
+ if (min_cuda_compute_capability is None or
+ compute_capability_from_device_desc(
+ local_device.physical_device_desc) >=
+ min_cuda_compute_capability):
+ return True
+ if local_device.device_type == "SYCL" and not cuda_only:
return True
- if local_device.device_type == "SYCL" and not cuda_only:
- return True
- return False
+ return False
+ except errors_impl.NotFoundError as e:
+ if not all([x in str(e) for x in ["CUDA", "not find"]]):
+ raise e
+ else:
+ logging.error(str(e))
+ return False
@contextlib.contextmanager
@@ -1256,9 +1270,9 @@ class TensorFlowTestCase(googletest.TestCase):
msg="Mismatched value: a%s is different from b%s." % (path_str,
path_str))
except TypeError as e:
- msg = "Error: a%s has %s, but b%s has %s" % (
- path_str, type(a), path_str, type(b))
- e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:])
+ msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
+ path_str, type(b))
+ e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
raise
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
@@ -1438,8 +1452,7 @@ class TensorFlowTestCase(googletest.TestCase):
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
- self.assertEqual(device1, device2,
- "Devices %s and %s are not equal. %s" %
+ self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
(device1, device2, msg))
# Fix Python 3 compatibility issues
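
The is_gpu_available change above makes the helper degrade gracefully when the CUDA runtime libraries are absent: a NotFoundError whose message mentions CUDA and "not find" is logged and swallowed, returning False, while any other NotFoundError still propagates. A minimal usage sketch under that behavior (run_gpu_only_checks is a hypothetical placeholder):

    from tensorflow.python.framework import test_util

    # On a host without the CUDA libraries this now returns False (after
    # logging the NotFoundError) instead of raising; unrelated
    # NotFoundErrors are still re-raised.
    if test_util.is_gpu_available(cuda_only=True,
                                  min_cuda_compute_capability=(3, 0)):
        run_gpu_only_checks()  # hypothetical GPU-only test body

Note that min_cuda_compute_capability is compared as a (major, minor) tuple against the pair parsed from the device description, so a tuple is the expected argument form.
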
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index d9571fa2be..ece1da0332 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1910,7 +1910,7 @@ cuda_py_test(
cuda_py_test(
name = "softmax_op_test",
- size = "small",
+ size = "medium",
srcs = ["softmax_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 2b8e99e18e..981f96b74d 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -18,14 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import unittest
import numpy as np
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
@test_util.with_c_api
@@ -41,9 +44,10 @@ class SoftmaxTest(test.TestCase):
features, axis=dim), one_only_on_dim))
softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
if log:
- return np.log(softmax)
+ res = np.log(softmax)
else:
- return softmax
+ res = softmax
+ return res
def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
# A previous version of the code checked the op name rather than the op type
@@ -53,9 +57,9 @@ class SoftmaxTest(test.TestCase):
np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
with self.test_session(use_gpu=use_gpu):
if log:
- tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name)
+ tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
else:
- tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name)
+ tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name)
out = tf_softmax.eval()
self.assertAllCloseAccordingToType(np_softmax, out)
self.assertShapeEqual(np_softmax, tf_softmax)
@@ -117,10 +121,32 @@ class SoftmaxTest(test.TestCase):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testFloatGPU(self):
+ if test.is_gpu_available(cuda_only=True):
+ rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+ cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+ for row, col in zip(rows, cols):
+ logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
+ data = np.random.rand(row, col)
+ self._testAll(data.astype(np.float32))
+
def testHalf(self):
self._testAll(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16))
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testHalfGPU(self):
+ if test.is_gpu_available(cuda_only=True):
+ rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+ cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+ for row, col in zip(rows, cols):
+ logging.info("Testing softmax half dtype in shape [%d, %d]", row, col)
+ data = np.random.rand(row, col)
+ self._testAll(data.astype(np.float16))
+
def testDouble(self):
self._testSoftmax(
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
@@ -169,7 +195,7 @@ class SoftmaxTest(test.TestCase):
self.assertEqual(0, array_ops.size(x).eval())
# reshape would raise if logits is empty
with self.assertRaises(errors_impl.InvalidArgumentError):
- nn_ops.softmax(x, dim=0).eval()
+ nn_ops.softmax(x, axis=0).eval()
def testDimTooLarge(self):
with self.test_session():
@@ -177,7 +203,7 @@ class SoftmaxTest(test.TestCase):
# inference error.
dim = array_ops.placeholder_with_default(100, shape=[])
with self.assertRaises(errors_impl.InvalidArgumentError):
- nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval()
+ nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval()
def testLargeDims(self):
# Make sure that we properly handle large inputs. See