From 19876ced76bd1730008e02fc4c43c2228faed38a Mon Sep 17 00:00:00 2001 From: Ilya Tokar Date: Mon, 16 Dec 2019 16:00:35 -0500 Subject: Bug #1785: Introduce numext::rint. This provides a new op that matches std::rint and previous behavior of pround. Also adds corresponding unsupported/../Tensor op. Performance is the same as e. g. floor (tested SSE/AVX). --- Eigen/src/Core/GenericPacketMath.h | 6 ++++ Eigen/src/Core/GlobalFunctions.h | 1 + Eigen/src/Core/MathFunctions.h | 39 +++++++++++++++++++++++++ Eigen/src/Core/arch/AVX/PacketMath.h | 5 +++- Eigen/src/Core/arch/SSE/PacketMath.h | 5 ++++ Eigen/src/Core/functors/UnaryFunctors.h | 19 ++++++++++++ Eigen/src/plugins/ArrayCwiseUnaryOps.h | 15 ++++++++++ doc/CoeffwiseMathFunctionsTable.dox | 11 +++++++ doc/snippets/Cwise_rint.cpp | 3 ++ test/array_cwise.cpp | 6 ++++ test/packetmath.cpp | 2 ++ unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 6 ++++ 12 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 doc/snippets/Cwise_rint.cpp diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 2e9fd4d7a..146d34fb5 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -92,6 +92,7 @@ struct default_packet_traits HasBetaInc = 0, HasRound = 0, + HasRint = 0, HasFloor = 0, HasCeil = 0, @@ -575,6 +576,11 @@ Packet pround(const Packet& a) { using numext::round; return round(a); } template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pfloor(const Packet& a) { using numext::floor; return floor(a); } +/** \internal \returns the rounded value of \a a (coeff-wise) with current + * rounding mode */ +template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet print(const Packet& a) { using numext::rint; return rint(a); } + /** \internal \returns the ceil of \a a (coeff-wise) */ template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index 7f132bdd0..8d54f92df 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -89,6 +89,7 @@ namespace Eigen EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cube,scalar_cube_op,cube (power 3),\sa Eigen::pow DOXCOMMA ArrayBase::cube) + EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rint,scalar_rint_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(round,scalar_round_op,nearest integer,\sa Eigen::floor DOXCOMMA Eigen::ceil DOXCOMMA ArrayBase::round) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(floor,scalar_floor_op,nearest integer not greater than the giben value,\sa Eigen::ceil DOXCOMMA ArrayBase::floor) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(ceil,scalar_ceil_op,nearest integer not less than the giben value,\sa Eigen::floor DOXCOMMA ArrayBase::ceil) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index dde329007..42e952c81 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -429,6 +429,38 @@ struct round_retval typedef Scalar type; }; +/**************************************************************************** +* Implementation of rint * +****************************************************************************/ + +template +struct rint_impl { + static inline Scalar run(const Scalar& x) + { + EIGEN_STATIC_ASSERT((!NumTraits::IsComplex), NUMERIC_TYPE_MUST_BE_REAL) +#if EIGEN_HAS_CXX11_MATH + EIGEN_USING_STD_MATH(rint); +#endif + return rint(x); + } +}; + +#if !EIGEN_HAS_CXX11_MATH +template<> +struct rint_impl { + static inline float run(const float& x) + { + return rintf(x); + } +}; +#endif + +template +struct rint_retval +{ + typedef Scalar type; +}; + /**************************************************************************** * Implementation of arg * ****************************************************************************/ @@ -1194,6 +1226,13 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isinf, isinf, bool) SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isfinite, isfinite, bool) #endif +template +EIGEN_DEVICE_FUNC +inline EIGEN_MATHFUNC_RETVAL(rint, Scalar) rint(const Scalar& x) +{ + return EIGEN_MATHFUNC_IMPL(rint, Scalar)::run(x); +} + template EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(round, Scalar) round(const Scalar& x) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index f83e358ba..11c7bcb43 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -81,7 +81,8 @@ template<> struct packet_traits : default_packet_traits HasBlend = 1, HasRound = 1, HasFloor = 1, - HasCeil = 1 + HasCeil = 1, + HasRint = 1 }; }; template<> struct packet_traits : default_packet_traits @@ -316,6 +317,8 @@ template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8 #endif } +template<> EIGEN_STRONG_INLINE Packet8f print(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet4d print(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); } template<> EIGEN_STRONG_INLINE Packet8f pceil(const Packet8f& a) { return _mm256_ceil_ps(a); } template<> EIGEN_STRONG_INLINE Packet4d pceil(const Packet4d& a) { return _mm256_ceil_pd(a); } diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 2f50326cb..d6a4a5c7f 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -124,6 +124,7 @@ struct packet_traits : default_packet_traits { #ifdef EIGEN_VECTORIZE_SSE4_1 , + HasRint = 1, HasRound = 1, HasCeil = 1 #endif @@ -148,6 +149,7 @@ struct packet_traits : default_packet_traits { #ifdef EIGEN_VECTORIZE_SSE4_1 , HasRound = 1, + HasRint = 1, HasFloor = 1, HasCeil = 1 #endif @@ -443,6 +445,9 @@ template<> EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) return _mm_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); } +template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { return _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) { return _mm_round_pd(a, _MM_FROUND_CUR_DIRECTION); } + template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) { return _mm_ceil_ps(a); } template<> EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) { return _mm_ceil_pd(a); } diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 4e6ce9a32..6d1b6ba51 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -735,6 +735,25 @@ struct functor_traits > }; }; +/** \internal + * \brief Template functor to compute the rounded (with current rounding mode) value of a scalar + * \sa class CwiseUnaryOp, ArrayBase::rint() + */ +template struct scalar_rint_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& a) const { return numext::rint(a); } + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::print(a); } +}; +template +struct functor_traits > +{ + enum { + Cost = NumTraits::MulCost, + PacketAccess = packet_traits::HasRint + }; +}; + /** \internal * \brief Template functor to compute the ceil of a scalar * \sa class CwiseUnaryOp, ArrayBase::ceil() diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/Eigen/src/plugins/ArrayCwiseUnaryOps.h index 06ac7aad0..59a4ee6a0 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.h @@ -32,6 +32,7 @@ typedef CwiseUnaryOp, const Derived> CoshReturn typedef CwiseUnaryOp, const Derived> SquareReturnType; typedef CwiseUnaryOp, const Derived> CubeReturnType; typedef CwiseUnaryOp, const Derived> RoundReturnType; +typedef CwiseUnaryOp, const Derived> RintReturnType; typedef CwiseUnaryOp, const Derived> FloorReturnType; typedef CwiseUnaryOp, const Derived> CeilReturnType; typedef CwiseUnaryOp, const Derived> IsNaNReturnType; @@ -427,6 +428,20 @@ cube() const return CubeReturnType(derived()); } +/** \returns an expression of the coefficient-wise rint of *this. + * + * Example: \include Cwise_rint.cpp + * Output: \verbinclude Cwise_rint.out + * + * \sa Math functions, ceil(), floor() + */ +EIGEN_DEVICE_FUNC +inline const RintReturnType +rint() const +{ + return RintReturnType(derived()); +} + /** \returns an expression of the coefficient-wise round of *this. * * Example: \include Cwise_round.cpp diff --git a/doc/CoeffwiseMathFunctionsTable.dox b/doc/CoeffwiseMathFunctionsTable.dox index 8186a5272..ce2f5e097 100644 --- a/doc/CoeffwiseMathFunctionsTable.dox +++ b/doc/CoeffwiseMathFunctionsTable.dox @@ -394,6 +394,17 @@ This also means that, unless specified, if the function \c std::foo is available plus \c using \c std::round ; \cpp11 SSE4,AVX,ZVector (f,d) + + + \anchor cwisetable_rint + a.\link ArrayBase::rint rint\endlink(); \n + \link Eigen::rint rint\endlink(a); + + nearest integer, \n rounding to nearest even in halfway cases + built-in generic implementation using \c std::rint + or \c rintf; + SSE4,AVX (f,d) + Floating point manipulation functions diff --git a/doc/snippets/Cwise_rint.cpp b/doc/snippets/Cwise_rint.cpp new file mode 100644 index 000000000..1dc7b2fd1 --- /dev/null +++ b/doc/snippets/Cwise_rint.cpp @@ -0,0 +1,3 @@ +ArrayXd v = ArrayXd::LinSpaced(7,-2,2); +cout << v << endl << endl; +cout << rint(v) << endl; diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 48ebcc88b..76fc83c33 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -296,6 +296,7 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m1.arg(), arg(m1)); VERIFY_IS_APPROX(m1.round(), round(m1)); + VERIFY_IS_APPROX(m1.rint(), rint(m1)); VERIFY_IS_APPROX(m1.floor(), floor(m1)); VERIFY_IS_APPROX(m1.ceil(), ceil(m1)); VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all()); @@ -331,6 +332,11 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(logistic(m1), (1.0/(1.0+exp(-m1)))); VERIFY_IS_APPROX(arg(m1), ((m1<0).template cast())*std::acos(-1.0)); VERIFY((round(m1) <= ceil(m1) && round(m1) >= floor(m1)).all()); + VERIFY((rint(m1) <= ceil(m1) && rint(m1) >= floor(m1)).all()); + VERIFY(((ceil(m1) - round(m1)) <= Scalar(0.5) || (round(m1) - floor(m1)) <= Scalar(0.5)).all()); + VERIFY(((ceil(m1) - round(m1)) <= Scalar(1.0) && (round(m1) - floor(m1)) <= Scalar(1.0)).all()); + VERIFY(((ceil(m1) - rint(m1)) <= Scalar(0.5) || (rint(m1) - floor(m1)) <= Scalar(0.5)).all()); + VERIFY(((ceil(m1) - rint(m1)) <= Scalar(1.0) && (rint(m1) - floor(m1)) <= Scalar(1.0)).all()); VERIFY((Eigen::isnan)((m1*0.0)/0.0).all()); VERIFY((Eigen::isinf)(m4/0.0).all()); VERIFY(((Eigen::isfinite)(m1) && (!(Eigen::isfinite)(m1*0.0/0.0)) && (!(Eigen::isfinite)(m4/0.0))).all()); diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 9564bc283..ba250443a 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -518,6 +518,7 @@ template void packetmath_real() CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround); CHECK_CWISE1_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil); CHECK_CWISE1_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor); + CHECK_CWISE1_IF(PacketTraits::HasRint, numext::rint, internal::print); // See bug 1785. for (int i=0; i void packetmath_real() data2[i] = -1.5 + i; } CHECK_CWISE1_IF(PacketTraits::HasRound, numext::round, internal::pround); + CHECK_CWISE1_IF(PacketTraits::HasRint, numext::rint, internal::print); for (int i=0; i return unaryExpr(internal::scalar_round_op()); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + rint() const { + return unaryExpr(internal::scalar_rint_op()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> ceil() const { -- cgit v1.2.3