about | summary | refs | log | tree | commit | diff | homepage
path: root/Eigen
diff options
context:
space:
mode:
author: Antonio Sanchez <cantonios@google.com> 2021-01-22 11:10:54 -0800
committer: Antonio Sanchez <cantonios@google.com> 2021-01-22 11:10:54 -0800
commit: f0e46ed5d41eeb450cbcbdb1ce3233d524ad3acd (patch)
tree: b2e862ad5f8c0788db4f3c39c0732db64fb5e217 /Eigen
parent: f19bcffee6b8018ca101ceb370e6e550a940289f (diff)
Fix pow and other cwise ops for half/bfloat16.
The new `generic_pow` implementation failed for half/bfloat16 because constructing those types from int/float is not `constexpr`; the `constexpr` qualifier was therefore removed in `GenericPacketMathFunctions`. Adding tests for half/bfloat16 exposed further issues related to implicit conversions. `numext::arg` also had to be implemented for types that are neither integer, complex, nor float/double/long double: such types seem to be implicitly converted to `std::complex<T>`, which then fails for half/bfloat16.
Diffstat (limited to 'Eigen')
-rw-r--r-- Eigen/src/Core/MathFunctions.h | 100
-rw-r--r-- Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h | 4
-rw-r--r-- Eigen/src/Core/functors/UnaryFunctors.h | 2
3 files changed, 62 insertions, 44 deletions
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 511a4276f..e29733c13 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -555,45 +555,63 @@ struct rint_retval
****************************************************************************/
#if EIGEN_HAS_CXX11_MATH
- template<typename Scalar>
- struct arg_impl {
- EIGEN_DEVICE_FUNC
- static inline Scalar run(const Scalar& x)
- {
- #if defined(EIGEN_HIP_DEVICE_COMPILE)
- // HIP does not seem to have a native device side implementation for the math routine "arg"
- using std::arg;
- #else
- EIGEN_USING_STD(arg);
- #endif
- return arg(x);
- }
- };
-#else
- template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
- struct arg_default_impl
+// std::arg is only defined for types of std::complex, or integer types or float/double/long double
+template<typename Scalar,
+ bool HasStdImpl = NumTraits<Scalar>::IsComplex || is_integral<Scalar>::value
+ || is_same<Scalar, float>::value || is_same<Scalar, double>::value
+ || is_same<Scalar, long double>::value >
+struct arg_default_impl;
+
+template<typename Scalar>
+struct arg_default_impl<Scalar, true> {
+ EIGEN_DEVICE_FUNC
+ static inline Scalar run(const Scalar& x)
{
- typedef typename NumTraits<Scalar>::Real RealScalar;
- EIGEN_DEVICE_FUNC
- static inline RealScalar run(const Scalar& x)
- {
- return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0); }
- };
+ #if defined(EIGEN_HIP_DEVICE_COMPILE)
+ // HIP does not seem to have a native device side implementation for the math routine "arg"
+ using std::arg;
+ #else
+ EIGEN_USING_STD(arg);
+ #endif
+ return static_cast<Scalar>(arg(x));
+ }
+};
- template<typename Scalar>
- struct arg_default_impl<Scalar,true>
+// Must be non-complex floating-point type (e.g. half/bfloat16).
+template<typename Scalar>
+struct arg_default_impl<Scalar, false> {
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
{
- typedef typename NumTraits<Scalar>::Real RealScalar;
- EIGEN_DEVICE_FUNC
- static inline RealScalar run(const Scalar& x)
- {
- EIGEN_USING_STD(arg);
- return arg(x);
- }
- };
+ return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
+ }
+};
+#else
+template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
+struct arg_default_impl
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
+ {
+ return (x < Scalar(0)) ? Scalar(EIGEN_PI) : Scalar(0);
+ }
+};
- template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
+template<typename Scalar>
+struct arg_default_impl<Scalar,true>
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+ EIGEN_DEVICE_FUNC
+ static inline RealScalar run(const Scalar& x)
+ {
+ EIGEN_USING_STD(arg);
+ return arg(x);
+ }
+};
#endif
+template<typename Scalar> struct arg_impl : arg_default_impl<Scalar> {};
template<typename Scalar>
struct arg_retval
@@ -1425,7 +1443,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T log(const T &x) {
EIGEN_USING_STD(log);
- return log(x);
+ return static_cast<T>(log(x));
}
#if defined(SYCL_DEVICE_ONLY)
@@ -1602,7 +1620,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T acosh(const T &x) {
EIGEN_USING_STD(acosh);
- return acosh(x);
+ return static_cast<T>(acosh(x));
}
#endif
@@ -1631,7 +1649,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T asinh(const T &x) {
EIGEN_USING_STD(asinh);
- return asinh(x);
+ return static_cast<T>(asinh(x));
}
#endif
@@ -1652,7 +1670,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T atan(const T &x) {
EIGEN_USING_STD(atan);
- return atan(x);
+ return static_cast<T>(atan(x));
}
#if EIGEN_HAS_CXX11_MATH
@@ -1660,7 +1678,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T atanh(const T &x) {
EIGEN_USING_STD(atanh);
- return atanh(x);
+ return static_cast<T>(atanh(x));
}
#endif
@@ -1682,7 +1700,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T cosh(const T &x) {
EIGEN_USING_STD(cosh);
- return cosh(x);
+ return static_cast<T>(cosh(x));
}
#if defined(SYCL_DEVICE_ONLY)
@@ -1701,7 +1719,7 @@ template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T sinh(const T &x) {
EIGEN_USING_STD(sinh);
- return sinh(x);
+ return static_cast<T>(sinh(x));
}
#if defined(SYCL_DEVICE_ONLY)
diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index 69c92a8cc..e3e91f4ab 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -804,8 +804,8 @@ EIGEN_STRONG_INLINE
void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
- EIGEN_CONSTEXPR Scalar shift_scale = Scalar(uint64_t(1) << shift);
- Packet gamma = pmul(pset1<Packet>(shift_scale + 1), x);
+ Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
+ Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma);
#else
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 976ecba59..c98fa573c 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -403,7 +403,7 @@ struct functor_traits<scalar_log10_op<Scalar> >
*/
template<typename Scalar> struct scalar_log2_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_log2_op)
- EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * std::log(a); }
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return Scalar(EIGEN_LOG2E) * numext::log(a); }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plog2(a); }
};