diff options
author | Gael Guennebaud <g.gael@free.fr> | 2016-06-23 14:27:20 +0200 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2016-06-23 14:27:20 +0200 |
commit | 76faf4a9657efeed089aeedc98a769410c32d3d7 (patch) | |
tree | bbfe9d39d8e12ae5cd90ac8cbcb69a54ae81c953 | |
parent | 67c12531e567629e84713fbb3150560c916bd08c (diff) |
Introduce a NumTraits<T>::Literal type to be used for literals, and
improve mixing type support in operations between arrays and scalars:
- 2 * ArrayXcf is now optimized in the sense that the integer 2 is properly promoted to a float instead of a complex<float> (fix a regression)
- 2.1 * ArrayXi is now forbiden (previously, 2.1 was converted to 2)
- This mechanism should be applicable to any custom scalar type, assuming NumTraits<T>::Literal is properly defined (it defaults to T)
-rw-r--r-- | Eigen/src/Core/NumTraits.h | 11 | ||||
-rw-r--r-- | Eigen/src/Core/util/Macros.h | 30 | ||||
-rw-r--r-- | Eigen/src/Core/util/XprHelper.h | 28 | ||||
-rw-r--r-- | test/mixingtypes.cpp | 5 | ||||
-rw-r--r-- | test/nesting_ops.cpp | 4 |
5 files changed, 51 insertions, 27 deletions
diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h index e065fa714..03f64a8e9 100644 --- a/Eigen/src/Core/NumTraits.h +++ b/Eigen/src/Core/NumTraits.h @@ -22,14 +22,16 @@ namespace Eigen { * This class stores enums, typedefs and static methods giving information about a numeric type. * * The provided data consists of: - * \li A typedef \a Real, giving the "real part" type of \a T. If \a T is already real, - * then \a Real is just a typedef to \a T. If \a T is \c std::complex<U> then \a Real + * \li A typedef \c Real, giving the "real part" type of \a T. If \a T is already real, + * then \c Real is just a typedef to \a T. If \a T is \c std::complex<U> then \c Real * is a typedef to \a U. - * \li A typedef \a NonInteger, giving the type that should be used for operations producing non-integral values, + * \li A typedef \c NonInteger, giving the type that should be used for operations producing non-integral values, * such as quotients, square roots, etc. If \a T is a floating-point type, then this typedef just gives * \a T again. Note however that many Eigen functions such as internal::sqrt simply refuse to * take integers. Outside of a few cases, Eigen doesn't do automatic type promotion. Thus, this typedef is * only intended as a helper for code that needs to explicitly promote types. + * \li A typedef \c Literal giving the type to use for numeric literals such as "2" or "0.5". For instance, for \c std::complex<U>, Literal is defined as \c U. + * Of course, this type must be fully compatible with \a T. In doubt, just use \a T here. * \li A typedef \a Nested giving the type to use to nest a value inside of the expression tree. If you don't know what * this means, just use \a T here. * \li An enum value \a IsComplex. It is equal to 1 if \a T is a \c std::complex @@ -84,6 +86,7 @@ template<typename T> struct GenericNumTraits T >::type NonInteger; typedef T Nested; + typedef T Literal; EIGEN_DEVICE_FUNC static inline Real epsilon() @@ -145,6 +148,7 @@ template<typename _Real> struct NumTraits<std::complex<_Real> > : GenericNumTraits<std::complex<_Real> > { typedef _Real Real; + typedef typename NumTraits<_Real>::Literal Literal; enum { IsComplex = 1, RequireInitialization = NumTraits<_Real>::RequireInitialization, @@ -168,6 +172,7 @@ struct NumTraits<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > typedef typename NumTraits<Scalar>::NonInteger NonIntegerScalar; typedef Array<NonIntegerScalar, Rows, Cols, Options, MaxRows, MaxCols> NonInteger; typedef ArrayType & Nested; + typedef typename NumTraits<Scalar>::Literal Literal; enum { IsComplex = NumTraits<Scalar>::IsComplex, diff --git a/Eigen/src/Core/util/Macros.h b/Eigen/src/Core/util/Macros.h index 87cc44657..6de21d2bb 100644 --- a/Eigen/src/Core/util/Macros.h +++ b/Eigen/src/Core/util/Macros.h @@ -906,35 +906,21 @@ namespace Eigen { const typename internal::plain_constant_type<EXPR,SCALAR>::type, const EXPR> #define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) \ - EIGEN_DEVICE_FUNC inline \ - const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME) \ - (METHOD)(const Scalar& scalar) const { \ - return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME)(derived(), \ - typename internal::plain_constant_type<Derived,Scalar>::type(derived().rows(), derived().cols(), scalar)); \ - } \ - \ template <typename T> EIGEN_DEVICE_FUNC inline \ - typename internal::enable_if<ScalarBinaryOpTraits<Scalar,T,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar,T> >::Defined, \ - const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME) >::type \ + const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA ScalarBinaryOpTraits<Scalar EIGEN_COMMA T EIGEN_COMMA EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar EIGEN_COMMA T> >::Defined>::type,OPNAME) \ (METHOD)(const T& scalar) const { \ - return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME)(derived(), \ - typename internal::plain_constant_type<Derived,T>::type(derived().rows(), derived().cols(), scalar)); \ + typedef typename internal::promote_scalar_arg<Scalar,T,ScalarBinaryOpTraits<Scalar,T,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar,T> >::Defined>::type PromotedT; \ + return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedT,OPNAME)(derived(), \ + typename internal::plain_constant_type<Derived,PromotedT>::type(derived().rows(), derived().cols(), internal::scalar_constant_op<PromotedT>(scalar))); \ } #define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \ - EIGEN_DEVICE_FUNC inline friend \ - const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME) \ - (METHOD)(const Scalar& scalar, const StorageBaseType& matrix) { \ - return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME)( \ - typename internal::plain_constant_type<Derived,Scalar>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \ - } \ - \ template <typename T> EIGEN_DEVICE_FUNC inline friend \ - typename internal::enable_if<ScalarBinaryOpTraits<T,Scalar,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T,Scalar> >::Defined, \ - const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME) >::type \ + const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA ScalarBinaryOpTraits<T EIGEN_COMMA Scalar EIGEN_COMMA EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T EIGEN_COMMA Scalar> >::Defined>::type,Derived,OPNAME) \ (METHOD)(const T& scalar, const StorageBaseType& matrix) { \ - return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME)( \ - typename internal::plain_constant_type<Derived,T>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \ + typedef typename internal::promote_scalar_arg<Scalar,T,ScalarBinaryOpTraits<T,Scalar,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T,Scalar> >::Defined>::type PromotedT; \ + return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(PromotedT,Derived,OPNAME)( \ + typename internal::plain_constant_type<Derived,PromotedT>::type(matrix.derived().rows(), matrix.derived().cols(), internal::scalar_constant_op<PromotedT>(scalar)), matrix.derived()); \ } #define EIGEN_MAKE_SCALAR_BINARY_OP(METHOD,OPNAME) \ diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index c41c408b0..b372ac1ad 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -45,6 +45,34 @@ inline IndexDest convert_index(const IndexSrc& idx) { } +// promote_scalar_arg is an helper used in operation between an expression and a scalar, like: +// expression * scalar +// Its role is to determine how the type T of the scalar operand should be promoted given the scalar type ExprScalar of the given expression. +// The IsSupported template parameter must be provided by the caller as: ScalarBinaryOpTraits<ExprScalar,T,op>::Defined using the proper order for ExprScalar and T. +// Then the logic is as follows: +// - if the operation is natively supported as defined by IsSupported, then the scalar type is not promoted, and T is returned. +// - otherwise, NumTraits<T>::Literal is returned if T is implicitly convertible to NumTraits<T>::Literal AND that this does not imply a float to integer conversion. +// - In all other cases, the promoted type is not defined, and the respective operation is thus invalid and not available (SFINAE). +template<typename ExprScalar,typename T, + bool IsSupported, + bool ConvertibleToLiteral = internal::is_convertible<T,typename NumTraits<ExprScalar>::Literal>::value, + bool IsSafe = NumTraits<T>::IsInteger || !NumTraits<typename NumTraits<ExprScalar>::Literal>::IsInteger> +struct promote_scalar_arg +{ +}; + +template<typename S,typename T, bool ConvertibleToLiteral, bool IsSafe> +struct promote_scalar_arg<S,T,true,ConvertibleToLiteral,IsSafe> +{ + typedef T type; +}; + +template<typename S,typename T> +struct promote_scalar_arg<S,T,false,true,true> +{ + typedef typename NumTraits<S>::Literal type; +}; + //classes inheriting no_assignment_operator don't generate a default operator=. class no_assignment_operator { diff --git a/test/mixingtypes.cpp b/test/mixingtypes.cpp index fe8c16470..57ef85c32 100644 --- a/test/mixingtypes.cpp +++ b/test/mixingtypes.cpp @@ -79,6 +79,11 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType) VERIFY_MIX_SCALAR(vf * scf , vf.template cast<complex<float> >() * scf); VERIFY_MIX_SCALAR(scd * vd , scd * vd.template cast<complex<double> >()); + VERIFY_MIX_SCALAR(vcf * 2 , vcf * complex<float>(2)); + VERIFY_MIX_SCALAR(vcf * 2.1 , vcf * complex<float>(2.1)); + VERIFY_MIX_SCALAR(2 * vcf, vcf * complex<float>(2)); + VERIFY_MIX_SCALAR(2.1 * vcf , vcf * complex<float>(2.1)); + // check scalar quotients VERIFY_MIX_SCALAR(vcf / sf , vcf / complex<float>(sf)); VERIFY_MIX_SCALAR(vf / scf , vf.template cast<complex<float> >() / scf); diff --git a/test/nesting_ops.cpp b/test/nesting_ops.cpp index 2f5025305..a419b0e44 100644 --- a/test/nesting_ops.cpp +++ b/test/nesting_ops.cpp @@ -75,8 +75,8 @@ template <typename MatrixType> void run_nesting_ops_2(const MatrixType& _m) } else { - VERIFY( verify_eval_type<1>(2*m1, 2*m1) ); - VERIFY( verify_eval_type<2>(2*m1, m1) ); + VERIFY( verify_eval_type<2>(2*m1, 2*m1) ); + VERIFY( verify_eval_type<3>(2*m1, m1) ); } VERIFY( verify_eval_type<2>(m1+m1, m1+m1) ); VERIFY( verify_eval_type<3>(m1+m1, m1) ); |