Diffstat (limited to 'tensorflow/core/util/cuda_kernel_helper.h')
-rw-r--r-- | tensorflow/core/util/cuda_kernel_helper.h | 73
1 file changed, 73 insertions, 0 deletions
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index a86567a7cc..86c55031c6 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include <algorithm>
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/platform/types.h"
 
 #define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -104,6 +105,78 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
   return __longlong_as_double(old);
 }
 
+// Helper functions for CudaAtomicAdd(half*, half), below.
+//
+// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
+// for a more efficient implementation, assuming that adding -0.0
+// will never harm the neighboring value. In this version, we take special
+// care to guarantee the bits of the untouched value are unchanged.
+inline __device__ uint32 add_to_low_half(uint32 val, float x) {
+  Eigen::half low_half;
+  low_half.x = static_cast<uint16>(val & 0xffffu);
+  low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
+  return (val & 0xffff0000u) | low_half.x;
+}
+
+inline __device__ uint32 add_to_high_half(uint32 val, float x) {
+  Eigen::half high_half;
+  high_half.x = static_cast<uint16>(val >> 16);
+  high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
+  return (val & 0xffffu) | (high_half.x << 16);
+}
+
+// Custom implementation of atomicAdd for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// Unlike the other atomic adds, this version is going to be very slow
+// under high concurrency, since most threads will be spinning on failing
+// their compare-and-swap tests. (The fact that we get false sharing on the
+// neighboring fp16 makes this even worse.) If you are doing a large
+// reduction, you are much better off doing the intermediate steps in fp32
+// and then switching to fp16 as late as you can in the calculations.
+//
+// Note: Assumes little endian.
+CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
+  float val_as_float(val);
+  intptr_t address_int = reinterpret_cast<intptr_t>(address);
+  if ((address_int & 0x2) == 0) {
+    // The half is in the first part of the uint32 (lower 16 bits).
+    uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
+    assert(((intptr_t)address_as_uint32 & 0x3) == 0);
+    uint32 old = *address_as_uint32, assumed;
+
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_uint32, assumed,
+                      add_to_low_half(assumed, val_as_float));
+
+      // Note: uses integer comparison to avoid hang in case of NaN
+    } while (assumed != old);
+
+    Eigen::half ret;
+    ret.x = old & 0xffffu;
+    return ret;
+  } else {
+    // The half is in the second part of the uint32 (upper 16 bits).
+    uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
+    assert(((intptr_t)address_as_uint32 & 0x3) == 0);
+    uint32 old = *address_as_uint32, assumed;
+
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_uint32, assumed,
+                      add_to_high_half(assumed, val_as_float));
+
+      // Note: uses integer comparison to avoid hang in case of NaN
+    } while (assumed != old);
+
+    Eigen::half ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+
 template <typename T>
 __global__ void SetZero(const int nthreads, T* bottom_diff) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
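
For reference, below is a minimal standalone sketch of the same technique: widen to the aligned 32-bit word, do the fp32 add on the targeted 16 bits, and retry with atomicCAS() until no other thread has raced on that word. It is written against CUDA's own __half type from <cuda_fp16.h> so it compiles without the TensorFlow or Eigen headers; the names AtomicAddHalfSketch and ScatterAddSketch are illustrative and not part of the patch. Like the patch, it assumes a little-endian layout and returns the old value, mirroring atomicAdd().

#include <cuda_fp16.h>
#include <cstdint>

// Sketch only: CAS-based fp16 atomic add, not the TensorFlow implementation.
__device__ __half AtomicAddHalfSketch(__half* address, float val) {
  uintptr_t address_int = reinterpret_cast<uintptr_t>(address);
  // Round down to the enclosing aligned 32-bit word; atomicCAS() needs it.
  unsigned int* word =
      reinterpret_cast<unsigned int*>(address_int & ~uintptr_t(0x3));
  const bool is_high = (address_int & 0x2) != 0;  // little endian assumed

  unsigned int old = *word, assumed;
  do {
    assumed = old;
    // Extract our 16 bits, add in fp32, and splice the sum back into the
    // word while leaving the neighboring half's bits exactly as they were.
    unsigned short bits = is_high
        ? static_cast<unsigned short>(assumed >> 16)
        : static_cast<unsigned short>(assumed & 0xffffu);
    float sum = __half2float(__ushort_as_half(bits)) + val;
    unsigned short new_bits = __half_as_ushort(__float2half(sum));
    unsigned int updated = is_high
        ? ((assumed & 0x0000ffffu) |
           (static_cast<unsigned int>(new_bits) << 16))
        : ((assumed & 0xffff0000u) | new_bits);
    // Integer comparison, so a stored NaN cannot make the loop hang.
    old = atomicCAS(word, assumed, updated);
  } while (assumed != old);

  unsigned short old_bits = is_high
      ? static_cast<unsigned short>(old >> 16)
      : static_cast<unsigned short>(old & 0xffffu);
  return __ushort_as_half(old_bits);
}

// Hypothetical caller: many threads funnel into a few fp16 buckets, which is
// exactly the contended case the patch comment warns about.
__global__ void ScatterAddSketch(__half* buckets, const float* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) AtomicAddHalfSketch(&buckets[i % 8], in[i]);
}

As the patch comment notes, a loop like ScatterAddSketch is the worst case for this primitive: for a large reduction it is usually better to accumulate partial sums in fp32 (for example with the plain float CudaAtomicAdd) and convert to fp16 only at the end.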