diff options
author | Gael Guennebaud <g.gael@free.fr> | 2016-06-14 14:10:07 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2016-06-14 14:10:07 +0200 |
commit | 396d9cfb6eaa62ecafc6f2d02851f0d2f8d6b4ef (patch) | |
tree | 77779a4ad40d3872751409ec77687c66a3247572 | |
parent | a9bb653a684b55fca1c9f58089f94871484ae50c (diff) |
Generalize expr.pow(scalar), pow(expr,scalar) and pow(scalar,expr).
Internal: scalar_pow_op (unary) is removed, and scalar_binary_pow_op is renamed scalar_pow_op.
-rw-r--r-- | Eigen/src/Core/GlobalFunctions.h | 58 | ||||
-rw-r--r-- | Eigen/src/Core/MathFunctions.h | 2 | ||||
-rw-r--r-- | Eigen/src/Core/functors/BinaryFunctors.h | 39 | ||||
-rw-r--r-- | Eigen/src/Core/util/ForwardDeclarations.h | 3 | ||||
-rw-r--r-- | Eigen/src/plugins/ArrayCwiseBinaryOps.h | 21 | ||||
-rw-r--r-- | Eigen/src/plugins/ArrayCwiseUnaryOps.h | 19 | ||||
-rw-r--r-- | test/mixingtypes.cpp | 47 |
7 files changed, 117 insertions, 72 deletions
diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index e489cefec..60e2ccfed 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -90,13 +90,31 @@ namespace Eigen /** \returns an expression of the coefficient-wise power of \a x to the given constant \a exponent. * + * \tparam ScalarExponent is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression (\c Derived::Scalar). + * * \sa ArrayBase::pow() + * + * \relates ArrayBase */ +#ifdef EIGEN_PARSED_BY_DOXYGEN + template<typename Derived,typename ScalarExponent> + inline const CwiseBinaryOp<internal::scalar_pow_op<Derived::Scalar,ScalarExponent>,Derived,Constant<ScalarExponent> > + pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent); +#else + template<typename Derived,typename ScalarExponent> + inline typename internal::enable_if< !(internal::is_same<typename Derived::Scalar,ScalarExponent>::value) + && ScalarBinaryOpTraits<typename Derived::Scalar,ScalarExponent>::Defined, + const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,ScalarExponent,pow) >::type + pow(const Eigen::ArrayBase<Derived>& x, const ScalarExponent& exponent) { + return x.derived().pow(exponent); + } + template<typename Derived> - inline const Eigen::CwiseUnaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar>, const Derived> + inline const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename Derived::Scalar,pow) pow(const Eigen::ArrayBase<Derived>& x, const typename Derived::Scalar& exponent) { return x.derived().pow(exponent); } +#endif /** \returns an expression of the coefficient-wise power of \a x to the given array of \a exponents. * @@ -106,12 +124,14 @@ namespace Eigen * Output: \verbinclude Cwise_array_power_array.out * * \sa ArrayBase::pow() + * + * \relates ArrayBase */ template<typename Derived,typename ExponentDerived> - inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived> + inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived> pow(const Eigen::ArrayBase<Derived>& x, const Eigen::ArrayBase<ExponentDerived>& exponents) { - return Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived>( + return Eigen::CwiseBinaryOp<Eigen::internal::scalar_pow_op<typename Derived::Scalar, typename ExponentDerived::Scalar>, const Derived, const ExponentDerived>( x.derived(), exponents.derived() ); @@ -120,23 +140,39 @@ namespace Eigen /** \returns an expression of the coefficient-wise power of the scalar \a x to the given array of \a exponents. * * This function computes the coefficient-wise power between a scalar and an array of exponents. - * Beaware that the scalar type of the input scalar \a x and the exponents \a exponents must be the same. + * + * \tparam Scalar is the scalar type of \a x. It must be compatible with the scalar type of the given array expression (\c Derived::Scalar). * * Example: \include Cwise_scalar_power_array.cpp * Output: \verbinclude Cwise_scalar_power_array.out * * \sa ArrayBase::pow() + * + * \relates ArrayBase */ +#ifdef EIGEN_PARSED_BY_DOXYGEN + template<typename Scalar,typename Derived> + inline const CwiseBinaryOp<internal::scalar_pow_op<Scalar,Derived::Scalar>,Constant<Scalar>,Derived> + pow(const Scalar& x,const Eigen::ArrayBase<Derived>& x); +#else + template<typename Scalar, typename Derived> + inline typename internal::enable_if< !(internal::is_same<typename Derived::Scalar,Scalar>::value) + && ScalarBinaryOpTraits<Scalar,typename Derived::Scalar>::Defined, + const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,pow) >::type + pow(const Scalar& x, const Eigen::ArrayBase<Derived>& exponents) + { + return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,pow)( + typename internal::plain_constant_type<Derived,Scalar>::type(exponents.rows(), exponents.cols(), x), exponents.derived() ); + } + template<typename Derived> - inline const Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename Derived::Scalar>, const typename Derived::ConstantReturnType, const Derived> - pow(const typename Derived::Scalar& x, const Eigen::ArrayBase<Derived>& exponents) + inline const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename Derived::Scalar,Derived,pow) + pow(const typename Derived::Scalar& x, const Eigen::ArrayBase<Derived>& exponents) { - typename Derived::ConstantReturnType constant_x(exponents.rows(), exponents.cols(), x); - return Eigen::CwiseBinaryOp<Eigen::internal::scalar_binary_pow_op<typename Derived::Scalar, typename Derived::Scalar>, const typename Derived::ConstantReturnType, const Derived>( - constant_x, - exponents.derived() - ); + return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename Derived::Scalar,Derived,pow)( + typename internal::plain_constant_type<Derived,typename Derived::Scalar>::type(exponents.rows(), exponents.cols(), x), exponents.derived() ); } +#endif /** * \brief Component-wise division of a scalar by array elements. diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 2a05ae12d..b683effab 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -498,7 +498,7 @@ template<typename ScalarX,typename ScalarY, bool IsInteger = NumTraits<ScalarX>: struct pow_impl { //typedef Scalar retval; - typedef typename ScalarBinaryOpTraits<ScalarX,ScalarY,internal::scalar_binary_pow_op<ScalarX,ScalarY> >::ReturnType result_type; + typedef typename ScalarBinaryOpTraits<ScalarX,ScalarY,internal::scalar_pow_op<ScalarX,ScalarY> >::ReturnType result_type; static EIGEN_DEVICE_FUNC inline result_type run(const ScalarX& x, const ScalarY& y) { EIGEN_USING_STD_MATH(pow); diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index fd7b5f8b4..9e40303be 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -295,16 +295,24 @@ struct functor_traits<scalar_hypot_op<Scalar,Scalar> > { /** \internal * \brief Template functor to compute the pow of two scalars */ -template<typename Scalar, typename OtherScalar> -struct scalar_binary_pow_op : binary_op_base<Scalar,OtherScalar> +template<typename Scalar, typename Exponent> +struct scalar_pow_op : binary_op_base<Scalar,Exponent> { - typedef typename ScalarBinaryOpTraits<Scalar,OtherScalar,scalar_binary_pow_op>::ReturnType result_type; - EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op) + typedef typename ScalarBinaryOpTraits<Scalar,Exponent,scalar_pow_op>::ReturnType result_type; +#ifndef EIGEN_SCALAR_BINARY_OP_PLUGIN + EIGEN_EMPTY_STRUCT_CTOR(scalar_pow_op) +#else + scalar_pow_op() { + typedef Scalar LhsScalar; + typedef Exponent RhsScalar; + EIGEN_SCALAR_BINARY_OP_PLUGIN + } +#endif EIGEN_DEVICE_FUNC - inline result_type operator() (const Scalar& a, const OtherScalar& b) const { return numext::pow(a, b); } + inline result_type operator() (const Scalar& a, const Exponent& b) const { return numext::pow(a, b); } }; -template<typename Scalar, typename OtherScalar> -struct functor_traits<scalar_binary_pow_op<Scalar,OtherScalar> > { +template<typename Scalar, typename Exponent> +struct functor_traits<scalar_pow_op<Scalar,Exponent> > { enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; }; @@ -614,23 +622,6 @@ struct functor_traits<scalar_sub_op<Scalar> > /** \internal - * \brief Template functor to raise a scalar to a power - * \sa class CwiseUnaryOp, Cwise::pow - */ -template<typename Scalar> -struct scalar_pow_op { - // FIXME default copy constructors seems bugged with std::complex<> - EIGEN_DEVICE_FUNC inline scalar_pow_op(const scalar_pow_op& other) : m_exponent(other.m_exponent) { } - EIGEN_DEVICE_FUNC inline scalar_pow_op(const Scalar& exponent) : m_exponent(exponent) {} - EIGEN_DEVICE_FUNC - inline Scalar operator() (const Scalar& a) const { return numext::pow(a, m_exponent); } - const Scalar m_exponent; -}; -template<typename Scalar> -struct functor_traits<scalar_pow_op<Scalar> > -{ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; }; - -/** \internal * \brief Template functor to compute the quotient between a scalar and array entries. * \sa class CwiseUnaryOp, Cwise::inverse() */ diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c1295dc5f..830f20f90 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -196,7 +196,6 @@ template<typename Scalar> struct scalar_sin_op; template<typename Scalar> struct scalar_acos_op; template<typename Scalar> struct scalar_asin_op; template<typename Scalar> struct scalar_tan_op; -template<typename Scalar> struct scalar_pow_op; template<typename Scalar> struct scalar_inverse_op; template<typename Scalar> struct scalar_square_op; template<typename Scalar> struct scalar_cube_op; @@ -209,10 +208,10 @@ template<typename Scalar> struct scalar_igamma_op; template<typename Scalar> struct scalar_igammac_op; template<typename Scalar> struct scalar_betainc_op; +template<typename Scalar,typename ScalarExponent> struct scalar_pow_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op; -template<typename ScalarX, typename ScalarY> struct scalar_binary_pow_op; } // end namespace internal diff --git a/Eigen/src/plugins/ArrayCwiseBinaryOps.h b/Eigen/src/plugins/ArrayCwiseBinaryOps.h index 8bf4e9b18..1e20e35d7 100644 --- a/Eigen/src/plugins/ArrayCwiseBinaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseBinaryOps.h @@ -82,7 +82,26 @@ max * Example: \include Cwise_array_power_array.cpp * Output: \verbinclude Cwise_array_power_array.out */ -EIGEN_MAKE_CWISE_BINARY_OP(pow,binary_pow) +EIGEN_MAKE_CWISE_BINARY_OP(pow,pow) + +#ifndef EIGEN_PARSED_BY_DOXYGEN +EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(pow,pow); +#else +/** \returns an expression of the coefficients of \c *this rasied to the constant power \a exponent + * + * \tparam T is the scalar type of \a exponent. It must be compatible with the scalar type of the given expression. + * + * This function computes the coefficient-wise power. The function MatrixBase::pow() in the + * unsupported module MatrixFunctions computes the matrix power. + * + * Example: \include Cwise_pow.cpp + * Output: \verbinclude Cwise_pow.out + * + * \sa ArrayBase::pow(ArrayBase), square(), cube(), exp(), log() + */ +template<typename T> +const CwiseBinaryOp<internal::scalar_pow_op<Scalar,T>,Derived,Constant<Scalar> > pow(const T& exponent) const; +#endif // TODO code generating macros could be moved to Macros.h and could include generation of documentation diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.h b/Eigen/src/plugins/ArrayCwiseUnaryOps.h index 775fa6ee0..4a6361d8a 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.h +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.h @@ -26,7 +26,6 @@ typedef CwiseUnaryOp<internal::scalar_lgamma_op<Scalar>, const Derived> LgammaRe typedef CwiseUnaryOp<internal::scalar_digamma_op<Scalar>, const Derived> DigammaReturnType; typedef CwiseUnaryOp<internal::scalar_erf_op<Scalar>, const Derived> ErfReturnType; typedef CwiseUnaryOp<internal::scalar_erfc_op<Scalar>, const Derived> ErfcReturnType; -typedef CwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived> PowReturnType; typedef CwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> SquareReturnType; typedef CwiseUnaryOp<internal::scalar_cube_op<Scalar>, const Derived> CubeReturnType; typedef CwiseUnaryOp<internal::scalar_round_op<Scalar>, const Derived> RoundReturnType; @@ -388,24 +387,6 @@ erfc() const return ErfcReturnType(derived()); } -/** \returns an expression of the coefficient-wise power of *this to the given exponent. - * - * This function computes the coefficient-wise power. The function MatrixBase::pow() in the - * unsupported module MatrixFunctions computes the matrix power. - * - * Example: \include Cwise_pow.cpp - * Output: \verbinclude Cwise_pow.out - * - * \sa exp(), log() - */ -EIGEN_DEVICE_FUNC -inline const PowReturnType -pow(const Scalar& exponent) const -{ - return PowReturnType(derived(), internal::scalar_pow_op<Scalar>(exponent)); -} - - /** \returns an expression of the coefficient-wise inverse of *this. * * Example: \include Cwise_inverse.cpp diff --git a/test/mixingtypes.cpp b/test/mixingtypes.cpp index cf2207114..b38271a17 100644 --- a/test/mixingtypes.cpp +++ b/test/mixingtypes.cpp @@ -23,10 +23,18 @@ #endif +static bool g_called; +#define EIGEN_SCALAR_BINARY_OP_PLUGIN { g_called |= (!internal::is_same<LhsScalar,RhsScalar>::value); } + #include "main.h" using namespace std; +#define VERIFY_MIX_SCALAR(XPR,REF) \ + g_called = false; \ + VERIFY_IS_APPROX(XPR,REF); \ + VERIFY( g_called && #XPR" not properly optimized"); + template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType) { typedef std::complex<float> CF; @@ -66,26 +74,34 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType) #endif // check scalar products - VERIFY_IS_APPROX(vcf * sf , vcf * complex<float>(sf)); - VERIFY_IS_APPROX(sd * vcd , complex<double>(sd) * vcd); - VERIFY_IS_APPROX(vf * scf , vf.template cast<complex<float> >() * scf); - VERIFY_IS_APPROX(scd * vd , scd * vd.template cast<complex<double> >()); + VERIFY_MIX_SCALAR(vcf * sf , vcf * complex<float>(sf)); + VERIFY_MIX_SCALAR(sd * vcd , complex<double>(sd) * vcd); + VERIFY_MIX_SCALAR(vf * scf , vf.template cast<complex<float> >() * scf); + VERIFY_MIX_SCALAR(scd * vd , scd * vd.template cast<complex<double> >()); // check scalar quotients - VERIFY_IS_APPROX(vcf / sf , vcf / complex<float>(sf)); - VERIFY_IS_APPROX(vf / scf , vf.template cast<complex<float> >() / scf); + VERIFY_MIX_SCALAR(vcf / sf , vcf / complex<float>(sf)); + VERIFY_MIX_SCALAR(vf / scf , vf.template cast<complex<float> >() / scf); // check scalar increment - VERIFY_IS_APPROX(vcf.array() + sf , vcf.array() + complex<float>(sf)); - VERIFY_IS_APPROX(sd + vcd.array(), complex<double>(sd) + vcd.array()); - VERIFY_IS_APPROX(vf.array() + scf, vf.template cast<complex<float> >().array() + scf); - VERIFY_IS_APPROX(scd + vd.array() , scd + vd.template cast<complex<double> >().array()); + VERIFY_MIX_SCALAR(vcf.array() + sf , vcf.array() + complex<float>(sf)); + VERIFY_MIX_SCALAR(sd + vcd.array(), complex<double>(sd) + vcd.array()); + VERIFY_MIX_SCALAR(vf.array() + scf, vf.template cast<complex<float> >().array() + scf); + VERIFY_MIX_SCALAR(scd + vd.array() , scd + vd.template cast<complex<double> >().array()); // check scalar subtractions - VERIFY_IS_APPROX(vcf.array() - sf , vcf.array() - complex<float>(sf)); - VERIFY_IS_APPROX(sd - vcd.array(), complex<double>(sd) - vcd.array()); - VERIFY_IS_APPROX(vf.array() - scf, vf.template cast<complex<float> >().array() - scf); - VERIFY_IS_APPROX(scd - vd.array() , scd - vd.template cast<complex<double> >().array()); + VERIFY_MIX_SCALAR(vcf.array() - sf , vcf.array() - complex<float>(sf)); + VERIFY_MIX_SCALAR(sd - vcd.array(), complex<double>(sd) - vcd.array()); + VERIFY_MIX_SCALAR(vf.array() - scf, vf.template cast<complex<float> >().array() - scf); + VERIFY_MIX_SCALAR(scd - vd.array() , scd - vd.template cast<complex<double> >().array()); + + // check scalar powers + VERIFY_MIX_SCALAR( pow(vcf.array(), sf), pow(vcf.array(), complex<float>(sf)) ); + VERIFY_MIX_SCALAR( vcf.array().pow(sf) , pow(vcf.array(), complex<float>(sf)) ); + VERIFY_MIX_SCALAR( pow(sd, vcd.array()), pow(complex<double>(sd), vcd.array()) ); + VERIFY_MIX_SCALAR( pow(vf.array(), scf), pow(vf.template cast<complex<float> >().array(), scf) ); + VERIFY_MIX_SCALAR( vf.array().pow(scf) , pow(vf.template cast<complex<float> >().array(), scf) ); + VERIFY_MIX_SCALAR( pow(scd, vd.array()), pow(scd, vd.template cast<complex<double> >().array()) ); // check dot product vf.dot(vf); @@ -215,6 +231,9 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType) VERIFY_IS_APPROX( md.array().pow(mcd.array()), md.template cast<CD>().eval().array().pow(mcd.array()) ); VERIFY_IS_APPROX( mcd.array().pow(md.array()), mcd.array().pow(md.template cast<CD>().eval().array()) ); + VERIFY_IS_APPROX( pow(md.array(),mcd.array()), md.template cast<CD>().eval().array().pow(mcd.array()) ); + VERIFY_IS_APPROX( pow(mcd.array(),md.array()), mcd.array().pow(md.template cast<CD>().eval().array()) ); + rcd = mcd; VERIFY_IS_APPROX( rcd = md, md.template cast<CD>().eval() ); rcd = mcd; |