aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core
diff options
context:
space:
mode:
authorGravatar Ilya Tokar <tokarip@google.com>2019-12-16 16:00:35 -0500
committerGravatar Ilya Tokar <tokarip@google.com>2020-01-07 21:22:44 +0000
commit19876ced76bd1730008e02fc4c43c2228faed38a (patch)
tree22f999134e79f84f6f9dc0d19bdfa6740db416f0 /Eigen/src/Core
parentd0ae052da4ce25a5b4306bfbb5bf8edcd010b663 (diff)
Bug #1785: Introduce numext::rint.
This provides a new op that matches std::rint and previous behavior of pround. Also adds corresponding unsupported/../Tensor op. Performance is the same as e. g. floor (tested SSE/AVX).
Diffstat (limited to 'Eigen/src/Core')
-rw-r--r--Eigen/src/Core/GenericPacketMath.h6
-rw-r--r--Eigen/src/Core/GlobalFunctions.h1
-rw-r--r--Eigen/src/Core/MathFunctions.h39
-rw-r--r--Eigen/src/Core/arch/AVX/PacketMath.h5
-rwxr-xr-xEigen/src/Core/arch/SSE/PacketMath.h5
-rw-r--r--Eigen/src/Core/functors/UnaryFunctors.h19
6 files changed, 74 insertions, 1 deletions
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index 2e9fd4d7a..146d34fb5 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -92,6 +92,7 @@ struct default_packet_traits
HasBetaInc = 0,
HasRound = 0,
+ HasRint = 0,
HasFloor = 0,
HasCeil = 0,
@@ -575,6 +576,11 @@ Packet pround(const Packet& a) { using numext::round; return round(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pfloor(const Packet& a) { using numext::floor; return floor(a); }
+/** \internal \returns the rounded value of \a a (coeff-wise) with current
+ * rounding mode */
+template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+Packet print(const Packet& a) { using numext::rint; return rint(a); }
+
/** \internal \returns the ceil of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h
index 7f132bdd0..8d54f92df 100644
--- a/Eigen/src/Core/GlobalFunctions.h
+++ b/Eigen/src/Core/GlobalFunctions.h
@@ -89,6 +89,7 @@ namespace Eigen
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cube,scalar_cube_op,cube (power 3),\sa Eigen::pow DOXCOMMA ArrayBase::cube)
+ EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rint,scalar_rint_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(round,scalar_round_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(floor,scalar_floor_op,nearest integer not greater than the giben value,\sa Eigen::ceil DOXCOMMA ArrayBase::floor)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(ceil,scalar_ceil_op,nearest integer not less than the giben value,\sa Eigen::floor DOXCOMMA ArrayBase::ceil)
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index dde329007..42e952c81 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -430,6 +430,38 @@ struct round_retval
};
/****************************************************************************
+* Implementation of rint *
+****************************************************************************/
+
+template<typename Scalar>
+struct rint_impl {
+ static inline Scalar run(const Scalar& x)
+ {
+ EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
+#if EIGEN_HAS_CXX11_MATH
+ EIGEN_USING_STD_MATH(rint);
+#endif
+ return rint(x);
+ }
+};
+
+#if !EIGEN_HAS_CXX11_MATH
+template<>
+struct rint_impl<float> {
+ static inline float run(const float& x)
+ {
+ return rintf(x);
+ }
+};
+#endif
+
+template<typename Scalar>
+struct rint_retval
+{
+ typedef Scalar type;
+};
+
+/****************************************************************************
* Implementation of arg *
****************************************************************************/
@@ -1196,6 +1228,13 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isfinite, isfinite, bool)
template<typename Scalar>
EIGEN_DEVICE_FUNC
+inline EIGEN_MATHFUNC_RETVAL(rint, Scalar) rint(const Scalar& x)
+{
+ return EIGEN_MATHFUNC_IMPL(rint, Scalar)::run(x);
+}
+
+template<typename Scalar>
+EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(round, Scalar) round(const Scalar& x)
{
return EIGEN_MATHFUNC_IMPL(round, Scalar)::run(x);
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index f83e358ba..11c7bcb43 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -81,7 +81,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasBlend = 1,
HasRound = 1,
HasFloor = 1,
- HasCeil = 1
+ HasCeil = 1,
+ HasRint = 1
};
};
template<> struct packet_traits<double> : default_packet_traits
@@ -316,6 +317,8 @@ template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8f print<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet4d print<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 2f50326cb..d6a4a5c7f 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -124,6 +124,7 @@ struct packet_traits<float> : default_packet_traits {
#ifdef EIGEN_VECTORIZE_SSE4_1
,
+ HasRint = 1,
HasRound = 1,
HasCeil = 1
#endif
@@ -148,6 +149,7 @@ struct packet_traits<double> : default_packet_traits {
#ifdef EIGEN_VECTORIZE_SSE4_1
,
HasRound = 1,
+ HasRint = 1,
HasFloor = 1,
HasCeil = 1
#endif
@@ -443,6 +445,9 @@ template<> EIGEN_STRONG_INLINE Packet2d pround<Packet2d>(const Packet2d& a)
return _mm_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
+template<> EIGEN_STRONG_INLINE Packet4f print<Packet4f>(const Packet4f& a) { return _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION); }
+template<> EIGEN_STRONG_INLINE Packet2d print<Packet2d>(const Packet2d& a) { return _mm_round_pd(a, _MM_FROUND_CUR_DIRECTION); }
+
template<> EIGEN_STRONG_INLINE Packet4f pceil<Packet4f>(const Packet4f& a) { return _mm_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet2d pceil<Packet2d>(const Packet2d& a) { return _mm_ceil_pd(a); }
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 4e6ce9a32..6d1b6ba51 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -736,6 +736,25 @@ struct functor_traits<scalar_floor_op<Scalar> >
};
/** \internal
+ * \brief Template functor to compute the rounded (with current rounding mode) value of a scalar
+ * \sa class CwiseUnaryOp, ArrayBase::rint()
+ */
+template<typename Scalar> struct scalar_rint_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::rint(a); }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::print(a); }
+};
+template<typename Scalar>
+struct functor_traits<scalar_rint_op<Scalar> >
+{
+ enum {
+ Cost = NumTraits<Scalar>::MulCost,
+ PacketAccess = packet_traits<Scalar>::HasRint
+ };
+};
+
+/** \internal
* \brief Template functor to compute the ceil of a scalar
* \sa class CwiseUnaryOp, ArrayBase::ceil()
*/