diff options
author | Antonio Sanchez <cantonios@google.com> | 2021-01-22 11:10:54 -0800 |
---|---|---|
committer | Antonio Sanchez <cantonios@google.com> | 2021-01-22 11:10:54 -0800 |
commit | f0e46ed5d41eeb450cbcbdb1ce3233d524ad3acd (patch) | |
tree | b2e862ad5f8c0788db4f3c39c0732db64fb5e217 /Eigen | |
parent | f19bcffee6b8018ca101ceb370e6e550a940289f (diff) |
Fix pow and other cwise ops for half/bfloat16.
The new `generic_pow` implementation was failing for half/bfloat16 since
their construction from int/float is not `constexpr`. Modified `veltkamp_splitting`
in `GenericPacketMathFunctions` to remove `constexpr` from `shift_scale`.
While adding tests for half/bfloat16, found other issues related to
implicit conversions.
Also needed to implement `numext::arg` for non-integer, non-complex,
non-float/double/long double types. These seem to be implicitly
converted to `std::complex<T>`, which then fails for half/bfloat16.
Diffstat (limited to 'Eigen')
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 100 | ||||
-rw-r--r-- | Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 4 | ||||
-rw-r--r-- | Eigen/src/Core/functors/UnaryFunctors.h | 2 |
3 files changed, 62 insertions, 44 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 511a4276f..e29733c13 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -555,45 +555,63 @@ struct rint_retval ****************************************************************************/ #if EIGEN_HAS_CXX11_MATH - template<typename Scalar> - struct arg_impl { - EIGEN_DEVICE_FUNC - static inline Scalar run(const Scalar& x) - { - #if defined(EIGEN_HIP_DEVICE_COMPILE) - // HIP does not seem to have a native device side implementation for the math routine "arg" - using std::arg; - #else - EIGEN_USING_STD(arg); - #endif - return arg(x); - } - }; -#else - template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex> - struct arg_default_impl +// std::arg is only defined for types of std::complex, or integer types or float/double/long double +template<typename Scalar, + bool HasStdImpl = NumTraits<Scalar>::IsComplex || is_integral<Scalar>::value + || is_same<Scalar, float>::value || is_same<Scalar, double>::value + || is_same<Scalar, long double>::value > +struct arg_default_impl; + +template<typename Scalar> +struct arg_default_impl<Scalar, true> { + EIGEN_DEVICE_FUNC + static inline Scalar run(const Scalar& x) { - typedef typename NumTraits<Scalar>::Real RealScalar; - EIGEN_DEVICE_FUNC - static inline RealScalar run(const Scalar& x) - { - return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); } - }; + #if defined(EIGEN_HIP_DEVICE_COMPILE) + // HIP does not seem to have a native device side implementation for the math routine "arg" + using std::arg; + #else + EIGEN_USING_STD(arg); + #endif + return static_cast<Scalar>(arg(x)); + } +}; - template<typename Scalar> - struct arg_default_impl<Scalar,true> +// Must be non-complex floating-point type (e.g. half/bfloat16). 
+template<typename Scalar> +struct arg_default_impl<Scalar, false> { + typedef typename NumTraits<Scalar>::Real RealScalar; + EIGEN_DEVICE_FUNC + static inline RealScalar run(const Scalar& x) { - typedef typename NumTraits<Scalar>::Real RealScalar; - EIGEN_DEVICE_FUNC - static inline RealScalar run(const Scalar& x) - { - EIGEN_USING_STD(arg); - return arg(x); - } - }; + return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); + } +}; +#else +template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex> +struct arg_default_impl +{ + typedef typename NumTraits<Scalar>::Real RealScalar; + EIGEN_DEVICE_FUNC + static inline RealScalar run(const Scalar& x) + { + return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); + } +}; - template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {}; +template<typename Scalar> +struct arg_default_impl<Scalar,true> +{ + typedef typename NumTraits<Scalar>::Real RealScalar; + EIGEN_DEVICE_FUNC + static inline RealScalar run(const Scalar& x) + { + EIGEN_USING_STD(arg); + return arg(x); + } +}; #endif +template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {}; template<typename Scalar> struct arg_retval @@ -1425,7 +1443,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x) { EIGEN_USING_STD(log); - return log(x); + return static_cast<T>(log(x)); } #if defined(SYCL_DEVICE_ONLY) @@ -1602,7 +1620,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T acosh(const T &x) { EIGEN_USING_STD(acosh); - return acosh(x); + return static_cast<T>(acosh(x)); } #endif @@ -1631,7 +1649,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T asinh(const T &x) { EIGEN_USING_STD(asinh); - return asinh(x); + return static_cast<T>(asinh(x)); } #endif @@ -1652,7 +1670,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T atan(const T &x) { EIGEN_USING_STD(atan); - return atan(x); + return static_cast<T>(atan(x)); } #if EIGEN_HAS_CXX11_MATH @@ -1660,7 +1678,7 @@ 
template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T atanh(const T &x) { EIGEN_USING_STD(atanh); - return atanh(x); + return static_cast<T>(atanh(x)); } #endif @@ -1682,7 +1700,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T cosh(const T &x) { EIGEN_USING_STD(cosh); - return cosh(x); + return static_cast<T>(cosh(x)); } #if defined(SYCL_DEVICE_ONLY) @@ -1701,7 +1719,7 @@ template<typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T sinh(const T &x) { EIGEN_USING_STD(sinh); - return sinh(x); + return static_cast<T>(sinh(x)); } #if defined(SYCL_DEVICE_ONLY) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 69c92a8cc..e3e91f4ab 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -804,8 +804,8 @@ EIGEN_STRONG_INLINE void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) { typedef typename unpacket_traits<Packet>::type Scalar; EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2; - EIGEN_CONSTEXPR Scalar shift_scale = Scalar(uint64_t(1) << shift); - Packet gamma = pmul(pset1<Packet>(shift_scale + 1), x); + Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr. 
+ Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x); #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma); #else diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 976ecba59..c98fa573c 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -403,7 +403,7 @@ struct functor_traits<scalar_log10_op<Scalar> > */ template<typename Scalar> struct scalar_log2_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * std::log(a); } + EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * numext::log(a); } template <typename Packet> EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); } }; |