diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h index cf97031be..2714117ab 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h @@ -31,30 +31,34 @@ namespace internal { template <typename T> struct TensorIntDivisor { public: - TensorIntDivisor() { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIntDivisor() { multiplier = 0; shift1 = 0; shift2 = 0; } // Must have 1 <= divider <= 2^31-1 - TensorIntDivisor(const T divider) { - static const int N = 32; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIntDivisor(const T divider) { + const int N = 32; eigen_assert(divider > 0); eigen_assert(divider <= (1<<(N-1)) - 1); // fast ln2 +#ifndef __CUDA_ARCH__ const int leading_zeros = __builtin_clz(divider); - const int l = N - (leading_zeros+1); - - multiplier = (static_cast<uint64_t>(1) << (N+l)) / divider - (static_cast<uint64_t>(1) << N) + 1; - shift1 = (std::min)(1, l); - shift2 = (std::max)(0, l-1); +#else + const int leading_zeros = __clz(divider); +#endif + const int log_div = N - (leading_zeros+1); + + multiplier = (static_cast<uint64_t>(1) << (N+log_div)) / divider - (static_cast<uint64_t>(1) << N) + 1; + shift1 = log_div > 1 ? 1 : log_div; + shift2 = log_div > 1 ? log_div-1 : 0; } // Must have 0 <= numerator <= 2^32-1 - T divide(const T numerator) const { - static const int N = 32; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T divide(const T numerator) const { + const int N = 32; eigen_assert(numerator >= 0); eigen_assert(numerator <= (1ull<<N) - 1); @@ -71,7 +75,7 @@ struct TensorIntDivisor { template <typename T> -static T operator / (const T& numerator, const TensorIntDivisor<T>& divisor) { +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator / (const T& numerator, const TensorIntDivisor<T>& divisor) { return divisor.divide(numerator); } |