diff options
-rw-r--r-- | tensorflow/core/kernels/softmax_op_gpu.cu.cc | 90 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 75 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/softmax_op_test.py | 38 |
4 files changed, 145 insertions, 60 deletions
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index 1f4a82a733..130d693dbd 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -33,8 +33,42 @@ namespace tensorflow { namespace { +template <typename U, typename T> +__device__ __host__ EIGEN_STRONG_INLINE + typename std::enable_if<!std::is_same<T, U>::value, U>::type + strict_cast(T t); + +template <typename U, typename T> +__device__ __host__ EIGEN_STRONG_INLINE + typename std::enable_if<std::is_same<T, U>::value, U>::type + strict_cast(T t) { + return t; +} + +template <> +__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>( + Eigen::half t) { + return functor::HalfToFloat()(t); +} + +template <> +__device__ __host__ EIGEN_STRONG_INLINE Eigen::half +strict_cast<Eigen::half, float>(float t) { + return functor::FloatToHalf()(t); +} + template <typename T> -__global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs, +struct softmax_traits { + using accumulator_type = T; +}; + +template <> +struct softmax_traits<Eigen::half> { + using accumulator_type = float; +}; + +template <typename T, typename U> +__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs, const T* max_logits, T* output, const int num_rows, const int num_cols, const bool in_log_space) { @@ -43,25 +77,33 @@ __global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs, const int row = tid / num_cols; const int col = tid % num_cols; + // TODO(jamesqin): change to half2 load when inputs are Eigen::half. + U input = strict_cast<U>(logits[tid]); + U max_val = strict_cast<U>(ldg(max_logits + row)); + U result; + if (row < num_rows && col < num_cols) { - if (in_log_space) - output[tid] = - logits[tid] - ldg(max_logits + row) - log(ldg(sum_probs + row)); - else - output[tid] = - exp(logits[tid] - ldg(max_logits + row)) / ldg(sum_probs + row); + if (in_log_space) { + result = input - max_val - log(ldg(sum_probs + row)); + } else { + result = exp(input - max_val) / ldg(sum_probs + row); + } + output[tid] = strict_cast<T>(result); } } -template <typename T> +template <typename T, typename U> struct SubtractAndExpFunctor { __host__ __device__ SubtractAndExpFunctor(const T* logits, const T* max_logits, const int num_cols) : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {} - __host__ __device__ T operator()(const int gid) const { - return exp(logits_[gid] - ldg(max_logits_ + gid / num_cols_)); + __host__ __device__ U operator()(const int gid) const { + // TODO(jamesqin): change to half2 load when inputs are Eigen::half. + const U diff = + strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_)); + return exp(diff); } const T* logits_; @@ -80,7 +122,6 @@ void DoRowReduction(OpKernelContext* context, T* output, InputIter input, functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>( context, output, input, 2, rows, cols, 1, 1, constants.kOne, op); } - } // namespace template <typename T> @@ -108,8 +149,10 @@ class SoftmaxOpGPU : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value, softmax_out->shape(), &max_logits)); + + typedef typename softmax_traits<T>::accumulator_type acc_type; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::value, + context->allocate_temp(DataTypeToEnum<acc_type>::value, softmax_out->shape(), &sum_probs)); DoRowReduction<T, cub::Max, const T*>( @@ -120,25 +163,28 @@ class SoftmaxOpGPU : public OpKernel { const int numBlocks = Eigen::divup(rows * cols, numThreads); cub::CountingInputIterator<int> counting_iterator(0); - typedef cub::TransformInputIterator<T, SubtractAndExpFunctor<T>, + typedef cub::TransformInputIterator<acc_type, + SubtractAndExpFunctor<T, acc_type>, cub::CountingInputIterator<int>> InputIterType; InputIterType input_itr( counting_iterator, - SubtractAndExpFunctor<T>( + SubtractAndExpFunctor<T, acc_type>( reinterpret_cast<const T*>(logits_in_.flat<T>().data()), reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols)); - DoRowReduction<T, cub::Sum, InputIterType>( - context, const_cast<T*>(sum_probs.flat<T>().data()), input_itr, rows, - cols); + DoRowReduction<acc_type, cub::Sum, InputIterType>( + context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()), + input_itr, rows, cols); - GenerateNormalizedProb<<<numBlocks, numThreads, 0, cu_stream>>>( - reinterpret_cast<const T*>(logits_in_.flat<T>().data()), - reinterpret_cast<const T*>(sum_probs.flat<T>().data()), - reinterpret_cast<const T*>(max_logits.flat<T>().data()), - const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_); + GenerateNormalizedProb<T, acc_type> + <<<numBlocks, numThreads, 0, cu_stream>>>( + reinterpret_cast<const T*>(logits_in_.flat<T>().data()), + reinterpret_cast<const acc_type*>( + sum_probs.flat<acc_type>().data()), + reinterpret_cast<const T*>(max_logits.flat<T>().data()), + const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_); } } diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index d8f8569939..43106b6e59 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -53,6 +53,7 @@ from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -201,6 +202,7 @@ def _strip_checkpoint_v2_randomized(graph_def): def IsGoogleCudaEnabled(): return pywrap_tensorflow.IsGoogleCudaEnabled() + def CudaSupportsHalfMatMulAndConv(): return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv() @@ -335,6 +337,8 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): # Make sure default graph reflects prev_value in case next test doesn't call # reset_default_graph(). ops.reset_default_graph() + + # pylint: disable=protected-access @@ -451,7 +455,8 @@ def with_c_api(cls): # If the C API is already enabled, don't do anything. Some tests break if the # same test is run twice, so this allows us to turn on the C API by default # without breaking these tests. - if ops._USE_C_API: return cls + if ops._USE_C_API: + return cls for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): @@ -469,6 +474,7 @@ def assert_no_new_pyobjects_executing_eagerly(f): Useful for checking that there are no missing Py_DECREFs in the C exercised by a bit of Python. """ + def decorator(self, **kwargs): """Warms up, gets an object count, runs the test, checks for new objects.""" with context.eager_mode(): @@ -483,8 +489,10 @@ def assert_no_new_pyobjects_executing_eagerly(f): new_count = len(gc.get_objects()) self.assertEqual(previous_count, new_count) gc.enable() + return decorator + def assert_no_new_tensors(f): """Decorator for asserting that no new Tensors persist after a test. @@ -508,17 +516,15 @@ def assert_no_new_tensors(f): def _is_tensorflow_object(obj): try: - return isinstance(obj, ( - ops.Tensor, - variables.Variable, - tensor_shape.Dimension, - tensor_shape.TensorShape)) + return isinstance(obj, + (ops.Tensor, variables.Variable, + tensor_shape.Dimension, tensor_shape.TensorShape)) except ReferenceError: # If the object no longer exists, we don't care about it. return False - tensors_before = set(id(obj) for obj in gc.get_objects() - if _is_tensorflow_object(obj)) + tensors_before = set( + id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) outside_graph_key = ops.get_default_graph()._graph_key with ops.Graph().as_default(): # Run the test in a new graph so that collections get cleared when it's @@ -572,18 +578,18 @@ def assert_no_garbage_created(f): "likely due to a reference cycle. New objects in cycle(s):") for i, obj in enumerate(gc.garbage[previous_garbage:]): try: - logging.error( - "Object %d of %d" % (i, len(gc.garbage) - previous_garbage)) + logging.error("Object %d of %d", i, + len(gc.garbage) - previous_garbage) + def _safe_object_str(obj): return "<%s %d>" % (obj.__class__.__name__, id(obj)) - logging.error(" Object type: %s" % (_safe_object_str(obj),)) - logging.error(" Referrer types: %s" % ( - ', '.join([_safe_object_str(ref) - for ref in gc.get_referrers(obj)]),)) - logging.error(" Referent types: %s" % ( - ', '.join([_safe_object_str(ref) - for ref in gc.get_referents(obj)]),)) - logging.error(" Object attribute names: %s" % (dir(obj),)) + + logging.error(" Object type: %s", _safe_object_str(obj)) + logging.error(" Referrer types: %s", ", ".join( + [_safe_object_str(ref) for ref in gc.get_referrers(obj)])) + logging.error(" Referent types: %s", ", ".join( + [_safe_object_str(ref) for ref in gc.get_referents(obj)])) + logging.error(" Object attribute names: %s", dir(obj)) logging.error(" Object __str__:") logging.error(obj) logging.error(" Object __repr__:") @@ -705,15 +711,23 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): return 0, 0 return int(match.group(1)), int(match.group(2)) - for local_device in device_lib.list_local_devices(): - if local_device.device_type == "GPU": - if (min_cuda_compute_capability is None or - compute_capability_from_device_desc(local_device.physical_device_desc) - >= min_cuda_compute_capability): + try: + for local_device in device_lib.list_local_devices(): + if local_device.device_type == "GPU": + if (min_cuda_compute_capability is None or + compute_capability_from_device_desc( + local_device.physical_device_desc) >= + min_cuda_compute_capability): + return True + if local_device.device_type == "SYCL" and not cuda_only: return True - if local_device.device_type == "SYCL" and not cuda_only: - return True - return False + return False + except errors_impl.NotFoundError as e: + if not all([x in str(e) for x in ["CUDA", "not find"]]): + raise e + else: + logging.error(str(e)) + return False @contextlib.contextmanager @@ -1256,9 +1270,9 @@ class TensorFlowTestCase(googletest.TestCase): msg="Mismatched value: a%s is different from b%s." % (path_str, path_str)) except TypeError as e: - msg = "Error: a%s has %s, but b%s has %s" % ( - path_str, type(a), path_str, type(b)) - e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) + msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), + path_str, type(b)) + e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): @@ -1438,8 +1452,7 @@ class TensorFlowTestCase(googletest.TestCase): """ device1 = pydev.canonical_name(device1) device2 = pydev.canonical_name(device2) - self.assertEqual(device1, device2, - "Devices %s and %s are not equal. %s" % + self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" % (device1, device2, msg)) # Fix Python 3 compatibility issues diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d9571fa2be..ece1da0332 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1910,7 +1910,7 @@ cuda_py_test( cuda_py_test( name = "softmax_op_test", - size = "small", + size = "medium", srcs = ["softmax_op_test.py"], additional_deps = [ "//third_party/py/numpy", diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index 2b8e99e18e..981f96b74d 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -18,14 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest import numpy as np + from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging @test_util.with_c_api @@ -41,9 +44,10 @@ class SoftmaxTest(test.TestCase): features, axis=dim), one_only_on_dim)) softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim) if log: - return np.log(softmax) + res = np.log(softmax) else: - return softmax + res = softmax + return res def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False): # A previous version of the code checked the op name rather than the op type @@ -53,9 +57,9 @@ class SoftmaxTest(test.TestCase): np_softmax = self._npSoftmax(np_features, dim=dim, log=log) with self.test_session(use_gpu=use_gpu): if log: - tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name) + tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name) else: - tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name) + tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name) out = tf_softmax.eval() self.assertAllCloseAccordingToType(np_softmax, out) self.assertShapeEqual(np_softmax, tf_softmax) @@ -117,10 +121,32 @@ class SoftmaxTest(test.TestCase): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32)) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testFloatGPU(self): + if test.is_gpu_available(cuda_only=True): + rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)] + cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)] + for row, col in zip(rows, cols): + logging.info("Testing softmax float dtype in shape [%d, %d]", row, col) + data = np.random.rand(row, col) + self._testAll(data.astype(np.float32)) + def testHalf(self): self._testAll( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16)) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testHalfGPU(self): + if test.is_gpu_available(cuda_only=True): + rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)] + cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)] + for row, col in zip(rows, cols): + logging.info("Testing softmax half dtype in shape [%d, %d]", row, col) + data = np.random.rand(row, col) + self._testAll(data.astype(np.float16)) + def testDouble(self): self._testSoftmax( np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64)) @@ -169,7 +195,7 @@ class SoftmaxTest(test.TestCase): self.assertEqual(0, array_ops.size(x).eval()) # reshape would raise if logits is empty with self.assertRaises(errors_impl.InvalidArgumentError): - nn_ops.softmax(x, dim=0).eval() + nn_ops.softmax(x, axis=0).eval() def testDimTooLarge(self): with self.test_session(): @@ -177,7 +203,7 @@ class SoftmaxTest(test.TestCase): # inference error. dim = array_ops.placeholder_with_default(100, shape=[]) with self.assertRaises(errors_impl.InvalidArgumentError): - nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval() + nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval() def testLargeDims(self): # Make sure that we properly handle large inputs. See |