diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 21:48:31 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 21:48:31 +0000 |
commit | c6953f799b01d36f4236b64f351cc1446e0abe17 (patch) | |
tree | 9abcded97c6effc010d08787c5b43ef7bb043b54 /Eigen/src/Core | |
parent | 807e51528d220c0efed870f0505dea81a5776085 (diff) |
Add packet generic ops `predux_fmin`, `predux_fmin_nan`, `predux_fmax`, and `predux_fmax_nan` that implement reductions with `PropagateNaN`, and `PropagateNumbers` semantics. Add (slow) generic implementations for most reductions.
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r-- | Eigen/src/Core/GenericPacketMath.h | 380 | ||||
-rw-r--r-- | Eigen/src/Core/functors/BinaryFunctors.h | 42 |
2 files changed, 224 insertions, 198 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 32cedd0b1..a734d99b7 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -215,74 +215,29 @@ pmul(const bool& a, const bool& b) { return a && b; } template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pdiv(const Packet& a, const Packet& b) { return a/b; } -/** \internal \returns the min of \a a and \a b (coeff-wise). - If \a a or \b b is NaN, the return value is implementation defined. */ +/** \internal \returns one bits */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pmin(const Packet& a, const Packet& b) { return numext::mini(a, b); } +ptrue(const Packet& /*a*/) { Packet b; memset((void*)&b, 0xff, sizeof(b)); return b;} -/** \internal \returns the max of \a a and \a b (coeff-wise) - If \a a or \b b is NaN, the return value is implementation defined. */ +/** \internal \returns zero bits */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); } +pzero(const Packet& /*a*/) { Packet b; memset((void*)&b, 0, sizeof(b)); return b;} -/** \internal \returns the absolute value of \a a */ +/** \internal \returns a <= b as a bit mask */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pabs(const Packet& a) { using std::abs; return abs(a); } -template<> EIGEN_DEVICE_FUNC inline unsigned int -pabs(const unsigned int& a) { return a; } -template<> EIGEN_DEVICE_FUNC inline unsigned long -pabs(const unsigned long& a) { return a; } -template<> EIGEN_DEVICE_FUNC inline unsigned long long -pabs(const unsigned long long& a) { return a; } +pcmp_le(const Packet& a, const Packet& b) { return a<=b ? ptrue(a) : pzero(a); } -/** \internal \returns the phase angle of \a a */ +/** \internal \returns a < b as a bit mask */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -parg(const Packet& a) { using numext::arg; return arg(a); } - - -/** \internal \returns \a a logically shifted by N bits to the right */ -template<int N> EIGEN_DEVICE_FUNC inline int -parithmetic_shift_right(const int& a) { return a >> N; } -template<int N> EIGEN_DEVICE_FUNC inline long int -parithmetic_shift_right(const long int& a) { return a >> N; } - -/** \internal \returns \a a arithmetically shifted by N bits to the right */ -template<int N> EIGEN_DEVICE_FUNC inline int -plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); } -template<int N> EIGEN_DEVICE_FUNC inline long int -plogical_shift_right(const long int& a) { return static_cast<long>(static_cast<unsigned long>(a) >> N); } - -/** \internal \returns \a a shifted by N bits to the left */ -template<int N> EIGEN_DEVICE_FUNC inline int -plogical_shift_left(const int& a) { return a << N; } -template<int N> EIGEN_DEVICE_FUNC inline long int -plogical_shift_left(const long int& a) { return a << N; } - -/** \internal \returns the significant and exponent of the underlying floating point numbers - * See https://en.cppreference.com/w/cpp/numeric/math/frexp - */ -template <typename Packet> -EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) { - int exp; - EIGEN_USING_STD(frexp); - Packet result = frexp(a, &exp); - exponent = static_cast<Packet>(exp); - return result; -} +pcmp_lt(const Packet& a, const Packet& b) { return a<b ? ptrue(a) : pzero(a); } -/** \internal \returns a * 2^exponent - * See https://en.cppreference.com/w/cpp/numeric/math/ldexp - */ +/** \internal \returns a == b as a bit mask */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pldexp(const Packet &a, const Packet &exponent) { - EIGEN_USING_STD(ldexp) - return ldexp(a, static_cast<int>(exponent)); -} +pcmp_eq(const Packet& a, const Packet& b) { return a==b ? ptrue(a) : pzero(a); } -/** \internal \returns zero bits */ +/** \internal \returns a < b or a==NaN or b==NaN as a bit mask */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pzero(const Packet& /*a*/) { Packet b; memset((void*)&b, 0, sizeof(b)); return b;} - +pcmp_lt_or_nan(const Packet& a, const Packet& b) { return a>=b ? pzero(a) : ptrue(a); } template<> EIGEN_DEVICE_FUNC inline float pzero<float>(const float& a) { EIGEN_UNUSED_VARIABLE(a) return 0.f; @@ -293,10 +248,6 @@ template<> EIGEN_DEVICE_FUNC inline double pzero<double>(const double& a) { return 0.; } -/** \internal \returns one bits */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -ptrue(const Packet& /*a*/) { Packet b; memset((void*)&b, 0xff, sizeof(b)); return b;} - template <typename RealScalar> EIGEN_DEVICE_FUNC inline std::complex<RealScalar> ptrue(const std::complex<RealScalar>& /*a*/) { RealScalar b; @@ -341,22 +292,6 @@ pxor(const Packet& a, const Packet& b) { template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) { return pand(a, pxor(ptrue(b), b)); } -/** \internal \returns a <= b as a bit mask */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pcmp_le(const Packet& a, const Packet& b) { return a<=b ? ptrue(a) : pzero(a); } - -/** \internal \returns a < b as a bit mask */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pcmp_lt(const Packet& a, const Packet& b) { return a<b ? ptrue(a) : pzero(a); } - -/** \internal \returns a == b as a bit mask */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pcmp_eq(const Packet& a, const Packet& b) { return a==b ? ptrue(a) : pzero(a); } - -/** \internal \returns a < b or a==NaN or b==NaN as a bit mask */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pcmp_lt_or_nan(const Packet& a, const Packet& b) { return a>=b ? pzero(a) : ptrue(a); } - /** \internal \returns \a or \b for each field in packet according to \mask */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pselect(const Packet& mask, const Packet& a, const Packet& b) { @@ -378,6 +313,119 @@ template<> EIGEN_DEVICE_FUNC inline bool pselect<bool>( return cond ? a : b; } +/** \internal \returns the min or of \a a and \a b (coeff-wise) + If either \a a or \a b are NaN, the result is implementation defined. */ +template<int NaNPropagation> +struct pminmax_impl { + template <typename Packet, typename Op> + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { + return op(a,b); + } +}; + +/** \internal \returns the min or max of \a a and \a b (coeff-wise) + If either \a a or \a b are NaN, NaN is returned. */ +template<> +struct pminmax_impl<PropagateNaN> { + template <typename Packet, typename Op> + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, op(a, b), b), + a); + } +}; + +/** \internal \returns the min or max of \a a and \a b (coeff-wise) + If both \a a and \a b are NaN, NaN is returned. + Equivalent to std::fmin(a, b). */ +template<> +struct pminmax_impl<PropagateNumbers> { + template <typename Packet, typename Op> + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet not_nan_mask_b = pcmp_eq(b, b); + return pselect(not_nan_mask_a, + pselect(not_nan_mask_b, op(a, b), a), + b); + } +}; + +/** \internal \returns the min of \a a and \a b (coeff-wise). + If \a a or \b b is NaN, the return value is implementation defined. */ +template<typename Packet> EIGEN_DEVICE_FUNC inline Packet +pmin(const Packet& a, const Packet& b) { return numext::mini(a,b); } + +/** \internal \returns the min of \a a and \a b (coeff-wise). + NaNPropagation determines the NaN propagation semantics. */ +template<int NaNPropagation, typename Packet> EIGEN_DEVICE_FUNC inline Packet +pmin(const Packet& a, const Packet& b) { return pminmax_impl<NaNPropagation>::run(a,b, pmin<Packet>); } + +/** \internal \returns the max of \a a and \a b (coeff-wise) + If \a a or \b b is NaN, the return value is implementation defined. */ +template<typename Packet> EIGEN_DEVICE_FUNC inline Packet +pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); } + +/** \internal \returns the max of \a a and \a b (coeff-wise). + NaNPropagation determines the NaN propagation semantics. */ +template<int NaNPropagation, typename Packet> EIGEN_DEVICE_FUNC inline Packet +pmax(const Packet& a, const Packet& b) { return pminmax_impl<NaNPropagation>::run(a,b, pmax<Packet>); } + +/** \internal \returns the absolute value of \a a */ +template<typename Packet> EIGEN_DEVICE_FUNC inline Packet +pabs(const Packet& a) { return numext::abs(a); } +template<> EIGEN_DEVICE_FUNC inline unsigned int +pabs(const unsigned int& a) { return a; } +template<> EIGEN_DEVICE_FUNC inline unsigned long +pabs(const unsigned long& a) { return a; } +template<> EIGEN_DEVICE_FUNC inline unsigned long long +pabs(const unsigned long long& a) { return a; } + +/** \internal \returns the phase angle of \a a */ +template<typename Packet> EIGEN_DEVICE_FUNC inline Packet +parg(const Packet& a) { using numext::arg; return arg(a); } + + +/** \internal \returns \a a logically shifted by N bits to the right */ +template<int N> EIGEN_DEVICE_FUNC inline int +parithmetic_shift_right(const int& a) { return a >> N; } +template<int N> EIGEN_DEVICE_FUNC inline long int +parithmetic_shift_right(const long int& a) { return a >> N; } + +/** \internal \returns \a a arithmetically shifted by N bits to the right */ +template<int N> EIGEN_DEVICE_FUNC inline int +plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); } +template<int N> EIGEN_DEVICE_FUNC inline long int +plogical_shift_right(const long int& a) { return static_cast<long>(static_cast<unsigned long>(a) >> N); } + +/** \internal \returns \a a shifted by N bits to the left */ +template<int N> EIGEN_DEVICE_FUNC inline int +plogical_shift_left(const int& a) { return a << N; } +template<int N> EIGEN_DEVICE_FUNC inline long int +plogical_shift_left(const long int& a) { return a << N; } + +/** \internal \returns the significant and exponent of the underlying floating point numbers + * See https://en.cppreference.com/w/cpp/numeric/math/frexp + */ +template <typename Packet> +EIGEN_DEVICE_FUNC inline Packet pfrexp(const Packet& a, Packet& exponent) { + int exp; + EIGEN_USING_STD(frexp); + Packet result = frexp(a, &exp); + exponent = static_cast<Packet>(exp); + return result; +} + +/** \internal \returns a * 2^exponent + * See https://en.cppreference.com/w/cpp/numeric/math/ldexp + */ +template<typename Packet> EIGEN_DEVICE_FUNC inline Packet +pldexp(const Packet &a, const Packet &exponent) { + EIGEN_USING_STD(ldexp) + return ldexp(a, static_cast<int>(exponent)); +} + /** \internal \returns the min of \a a and \a b (coeff-wise) */ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet pabsdiff(const Packet& a, const Packet& b) { return pselect(pcmp_lt(a, b), psub(b, a), psub(a, b)); } @@ -507,57 +555,6 @@ template<typename Scalar> EIGEN_DEVICE_FUNC inline void prefetch(const Scalar* a #endif } -/** \internal \returns the first element of a packet */ -template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type pfirst(const Packet& a) -{ return a; } - -/** \internal \returns the sum of the elements of \a a*/ -template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux(const Packet& a) -{ return a; } - -/** \internal \returns the sum of the elements of upper and lower half of \a a if \a a is larger than 4. - * For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7} - * For packet-size smaller or equal to 4, this boils down to a noop. - */ -template<typename Packet> EIGEN_DEVICE_FUNC inline -typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type -predux_half_dowto4(const Packet& a) -{ return a; } - -/** \internal \returns the product of the elements of \a a */ -template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(const Packet& a) -{ return a; } - -/** \internal \returns the min of the elements of \a a */ -template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) -{ return a; } - -/** \internal \returns the max of the elements of \a a */ -template<typename Packet> EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a) -{ return a; } - -/** \internal \returns true if all coeffs of \a a means "true" - * It is supposed to be called on values returned by pcmp_*. - */ -// not needed yet -// template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_all(const Packet& a) -// { return bool(a); } - -/** \internal \returns true if any coeffs of \a a means "true" - * It is supposed to be called on values returned by pcmp_*. - */ -template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a) -{ - // Dirty but generic implementation where "true" is assumed to be non 0 and all the sames. - // It is expected that "true" is either: - // - Scalar(1) - // - bits full of ones (NaN for floats), - // - or first bit equals to 1 (1 for ints, smallest denormal for floats). - // For all these cases, taking the sum is just fine, and this boils down to a no-op for scalars. - typedef typename unpacket_traits<Packet>::type Scalar; - return numext::not_equal_strict(predux(a), Scalar(0)); -} - /** \internal \returns the reversed elements of \a a*/ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet preverse(const Packet& a) { return a; } @@ -656,53 +653,104 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); } template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } +/** \internal \returns the first element of a packet */ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +pfirst(const Packet& a) +{ return a; } -/** \internal \returns the max of \a a and \a b (coeff-wise) - If both \a a and \a b are NaN, NaN is returned. - Equivalent to std::fmax(a, b). */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pfmax(const Packet& a, const Packet& b) { - Packet not_nan_mask_a = pcmp_eq(a, a); - Packet not_nan_mask_b = pcmp_eq(b, b); - return pselect(not_nan_mask_a, - pselect(not_nan_mask_b, pmax(a, b), a), - b); +/** \internal \returns the sum of the elements of upper and lower half of \a a if \a a is larger than 4. + * For a packet {a0, a1, a2, a3, a4, a5, a6, a7}, it returns a half packet {a0+a4, a1+a5, a2+a6, a3+a7} + * For packet-size smaller or equal to 4, this boils down to a noop. + */ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename conditional<(unpacket_traits<Packet>::size%8)==0,typename unpacket_traits<Packet>::half,Packet>::type +predux_half_dowto4(const Packet& a) +{ return a; } + +// Slow generic implementation of Packet reduction. +template <typename Packet, typename Op> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_helper(const Packet& a, Op op) { + typedef typename unpacket_traits<Packet>::type Scalar; + const size_t n = unpacket_traits<Packet>::size; + Scalar elements[n]; + pstoreu<Scalar>(elements, a); + for(size_t k = n / 2; k > 0; k /= 2) { + for(size_t i = 0; i < k; ++i) { + elements[i] = op(elements[i], elements[i + k]); + } + } + return elements[0]; } -/** \internal \returns the min of \a a and \a b (coeff-wise) - If both \a a and \a b are NaN, NaN is returned. - Equivalent to std::fmin(a, b). */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pfmin(const Packet& a, const Packet& b) { - Packet not_nan_mask_a = pcmp_eq(a, a); - Packet not_nan_mask_b = pcmp_eq(b, b); - return pselect(not_nan_mask_a, - pselect(not_nan_mask_b, pmin(a, b), a), - b); +/** \internal \returns the sum of the elements of \a a*/ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux(const Packet& a) +{ + return predux_helper(a, padd<typename unpacket_traits<Packet>::type>); } -/** \internal \returns the max of \a a and \a b (coeff-wise) - If either \a a or \a b are NaN, NaN is returned. */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pfmax_nan(const Packet& a, const Packet& b) { - Packet not_nan_mask_a = pcmp_eq(a, a); - Packet not_nan_mask_b = pcmp_eq(b, b); - return pselect(not_nan_mask_a, - pselect(not_nan_mask_b, pmax(a, b), b), - a); +/** \internal \returns the product of the elements of \a a */ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_mul(const Packet& a) +{ + return predux_helper(a, pmul<typename unpacket_traits<Packet>::type>); } -/** \internal \returns the min of \a a and \a b (coeff-wise) - If either \a a or \a b are NaN, NaN is returned. */ -template<typename Packet> EIGEN_DEVICE_FUNC inline Packet -pfmin_nan(const Packet& a, const Packet& b) { - Packet not_nan_mask_a = pcmp_eq(a, a); - Packet not_nan_mask_b = pcmp_eq(b, b); - return pselect(not_nan_mask_a, - pselect(not_nan_mask_b, pmin(a, b), b), - a); +/** \internal \returns the min of the elements of \a a */ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_min(const Packet& a) +{ + return predux_helper(a, pmin<PropagateFast, typename unpacket_traits<Packet>::type>); } +template<int NaNPropagation, typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_min(const Packet& a) +{ + return predux_helper(a, pmin<NaNPropagation, typename unpacket_traits<Packet>::type>); +} + +/** \internal \returns the max of the elements of \a a */ +template<typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_max(const Packet& a) +{ + return predux_helper(a, pmax<PropagateFast, typename unpacket_traits<Packet>::type>); +} + +template<int NaNPropagation, typename Packet> +EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type +predux_max(const Packet& a) +{ + return predux_helper(a, pmax<NaNPropagation, typename unpacket_traits<Packet>::type>); +} + +/** \internal \returns true if all coeffs of \a a means "true" + * It is supposed to be called on values returned by pcmp_*. + */ +// not needed yet +// template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_all(const Packet& a) +// { return bool(a); } + +/** \internal \returns true if any coeffs of \a a means "true" + * It is supposed to be called on values returned by pcmp_*. + */ +template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& a) +{ + // Dirty but generic implementation where "true" is assumed to be non 0 and all the sames. + // It is expected that "true" is either: + // - Scalar(1) + // - bits full of ones (NaN for floats), + // - or first bit equals to 1 (1 for ints, smallest denormal for floats). + // For all these cases, taking the sum is just fine, and this boils down to a no-op for scalars. + typedef typename unpacket_traits<Packet>::type Scalar; + return numext::not_equal_strict(predux(a), Scalar(0)); +} /*************************************************************************** * The following functions might not have to be overwritten for vectorized types diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index 55650bb8d..f3509c4b9 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -140,29 +140,18 @@ struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar> typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { - if (NaNPropagation == PropagateFast) { - return numext::mini(a, b); - } else if (NaNPropagation == PropagateNumbers) { - return internal::pfmin(a,b); - } else if (NaNPropagation == PropagateNaN) { - return internal::pfmin_nan(a,b); - } + return internal::pmin<NaNPropagation>(a, b); } template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const { - if (NaNPropagation == PropagateFast) { - return internal::pmin(a,b); - } else if (NaNPropagation == PropagateNumbers) { - return internal::pfmin(a,b); - } else if (NaNPropagation == PropagateNaN) { - return internal::pfmin_nan(a,b); - } + return internal::pmin<NaNPropagation>(a,b); } - // TODO(rmlarsen): Handle all NaN propagation semantics reductions. template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const - { return internal::predux_min(a); } + { + return internal::predux_min<NaNPropagation>(a); + } }; template<typename LhsScalar,typename RhsScalar, int NaNPropagation> @@ -184,29 +173,18 @@ struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar> typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type; EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { - if (NaNPropagation == PropagateFast) { - return numext::maxi(a, b); - } else if (NaNPropagation == PropagateNumbers) { - return internal::pfmax(a,b); - } else if (NaNPropagation == PropagateNaN) { - return internal::pfmax_nan(a,b); - } + return internal::pmax<NaNPropagation>(a,b); } template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const { - if (NaNPropagation == PropagateFast) { - return internal::pmax(a,b); - } else if (NaNPropagation == PropagateNumbers) { - return internal::pfmax(a,b); - } else if (NaNPropagation == PropagateNaN) { - return internal::pfmax_nan(a,b); - } + return internal::pmax<NaNPropagation>(a,b); } - // TODO(rmlarsen): Handle all NaN propagation semantics reductions. template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const - { return internal::predux_max(a); } + { + return internal::predux_max<NaNPropagation>(a); + } }; template<typename LhsScalar,typename RhsScalar, int NaNPropagation> |