aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2016-06-23 14:27:20 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2016-06-23 14:27:20 +0200
commit76faf4a9657efeed089aeedc98a769410c32d3d7 (patch)
treebbfe9d39d8e12ae5cd90ac8cbcb69a54ae81c953
parent67c12531e567629e84713fbb3150560c916bd08c (diff)
Introduce a NumTraits<T>::Literal type to be used for literals, and
improve mixing type support in operations between arrays and scalars: - 2 * ArrayXcf is now optimized in the sense that the integer 2 is properly promoted to a float instead of a complex<float> (fix a regression) - 2.1 * ArrayXi is now forbiden (previously, 2.1 was converted to 2) - This mechanism should be applicable to any custom scalar type, assuming NumTraits<T>::Literal is properly defined (it defaults to T)
-rw-r--r--Eigen/src/Core/NumTraits.h11
-rw-r--r--Eigen/src/Core/util/Macros.h30
-rw-r--r--Eigen/src/Core/util/XprHelper.h28
-rw-r--r--test/mixingtypes.cpp5
-rw-r--r--test/nesting_ops.cpp4
5 files changed, 51 insertions, 27 deletions
diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h
index e065fa714..03f64a8e9 100644
--- a/Eigen/src/Core/NumTraits.h
+++ b/Eigen/src/Core/NumTraits.h
@@ -22,14 +22,16 @@ namespace Eigen {
* This class stores enums, typedefs and static methods giving information about a numeric type.
*
* The provided data consists of:
- * \li A typedef \a Real, giving the "real part" type of \a T. If \a T is already real,
- * then \a Real is just a typedef to \a T. If \a T is \c std::complex<U> then \a Real
+ * \li A typedef \c Real, giving the "real part" type of \a T. If \a T is already real,
+ * then \c Real is just a typedef to \a T. If \a T is \c std::complex<U> then \c Real
* is a typedef to \a U.
- * \li A typedef \a NonInteger, giving the type that should be used for operations producing non-integral values,
+ * \li A typedef \c NonInteger, giving the type that should be used for operations producing non-integral values,
* such as quotients, square roots, etc. If \a T is a floating-point type, then this typedef just gives
* \a T again. Note however that many Eigen functions such as internal::sqrt simply refuse to
* take integers. Outside of a few cases, Eigen doesn't do automatic type promotion. Thus, this typedef is
* only intended as a helper for code that needs to explicitly promote types.
+ * \li A typedef \c Literal giving the type to use for numeric literals such as "2" or "0.5". For instance, for \c std::complex<U>, Literal is defined as \c U.
+ * Of course, this type must be fully compatible with \a T. In doubt, just use \a T here.
* \li A typedef \a Nested giving the type to use to nest a value inside of the expression tree. If you don't know what
* this means, just use \a T here.
* \li An enum value \a IsComplex. It is equal to 1 if \a T is a \c std::complex
@@ -84,6 +86,7 @@ template<typename T> struct GenericNumTraits
T
>::type NonInteger;
typedef T Nested;
+ typedef T Literal;
EIGEN_DEVICE_FUNC
static inline Real epsilon()
@@ -145,6 +148,7 @@ template<typename _Real> struct NumTraits<std::complex<_Real> >
: GenericNumTraits<std::complex<_Real> >
{
typedef _Real Real;
+ typedef typename NumTraits<_Real>::Literal Literal;
enum {
IsComplex = 1,
RequireInitialization = NumTraits<_Real>::RequireInitialization,
@@ -168,6 +172,7 @@ struct NumTraits<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
typedef typename NumTraits<Scalar>::NonInteger NonIntegerScalar;
typedef Array<NonIntegerScalar, Rows, Cols, Options, MaxRows, MaxCols> NonInteger;
typedef ArrayType & Nested;
+ typedef typename NumTraits<Scalar>::Literal Literal;
enum {
IsComplex = NumTraits<Scalar>::IsComplex,
diff --git a/Eigen/src/Core/util/Macros.h b/Eigen/src/Core/util/Macros.h
index 87cc44657..6de21d2bb 100644
--- a/Eigen/src/Core/util/Macros.h
+++ b/Eigen/src/Core/util/Macros.h
@@ -906,35 +906,21 @@ namespace Eigen {
const typename internal::plain_constant_type<EXPR,SCALAR>::type, const EXPR>
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) \
- EIGEN_DEVICE_FUNC inline \
- const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME) \
- (METHOD)(const Scalar& scalar) const { \
- return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,Scalar,OPNAME)(derived(), \
- typename internal::plain_constant_type<Derived,Scalar>::type(derived().rows(), derived().cols(), scalar)); \
- } \
- \
template <typename T> EIGEN_DEVICE_FUNC inline \
- typename internal::enable_if<ScalarBinaryOpTraits<Scalar,T,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar,T> >::Defined, \
- const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME) >::type \
+ const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA ScalarBinaryOpTraits<Scalar EIGEN_COMMA T EIGEN_COMMA EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar EIGEN_COMMA T> >::Defined>::type,OPNAME) \
(METHOD)(const T& scalar) const { \
- return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,T,OPNAME)(derived(), \
- typename internal::plain_constant_type<Derived,T>::type(derived().rows(), derived().cols(), scalar)); \
+ typedef typename internal::promote_scalar_arg<Scalar,T,ScalarBinaryOpTraits<Scalar,T,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<Scalar,T> >::Defined>::type PromotedT; \
+ return EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(Derived,PromotedT,OPNAME)(derived(), \
+ typename internal::plain_constant_type<Derived,PromotedT>::type(derived().rows(), derived().cols(), internal::scalar_constant_op<PromotedT>(scalar))); \
}
#define EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD,OPNAME) \
- EIGEN_DEVICE_FUNC inline friend \
- const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME) \
- (METHOD)(const Scalar& scalar, const StorageBaseType& matrix) { \
- return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar,Derived,OPNAME)( \
- typename internal::plain_constant_type<Derived,Scalar>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \
- } \
- \
template <typename T> EIGEN_DEVICE_FUNC inline friend \
- typename internal::enable_if<ScalarBinaryOpTraits<T,Scalar,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T,Scalar> >::Defined, \
- const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME) >::type \
+ const EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(typename internal::promote_scalar_arg<Scalar EIGEN_COMMA T EIGEN_COMMA ScalarBinaryOpTraits<T EIGEN_COMMA Scalar EIGEN_COMMA EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T EIGEN_COMMA Scalar> >::Defined>::type,Derived,OPNAME) \
(METHOD)(const T& scalar, const StorageBaseType& matrix) { \
- return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(T,Derived,OPNAME)( \
- typename internal::plain_constant_type<Derived,T>::type(matrix.derived().rows(), matrix.derived().cols(), scalar), matrix.derived()); \
+ typedef typename internal::promote_scalar_arg<Scalar,T,ScalarBinaryOpTraits<T,Scalar,EIGEN_CAT(EIGEN_CAT(internal::scalar_,OPNAME),_op)<T,Scalar> >::Defined>::type PromotedT; \
+ return EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(PromotedT,Derived,OPNAME)( \
+ typename internal::plain_constant_type<Derived,PromotedT>::type(matrix.derived().rows(), matrix.derived().cols(), internal::scalar_constant_op<PromotedT>(scalar)), matrix.derived()); \
}
#define EIGEN_MAKE_SCALAR_BINARY_OP(METHOD,OPNAME) \
diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h
index c41c408b0..b372ac1ad 100644
--- a/Eigen/src/Core/util/XprHelper.h
+++ b/Eigen/src/Core/util/XprHelper.h
@@ -45,6 +45,34 @@ inline IndexDest convert_index(const IndexSrc& idx) {
}
+// promote_scalar_arg is an helper used in operation between an expression and a scalar, like:
+// expression * scalar
+// Its role is to determine how the type T of the scalar operand should be promoted given the scalar type ExprScalar of the given expression.
+// The IsSupported template parameter must be provided by the caller as: ScalarBinaryOpTraits<ExprScalar,T,op>::Defined using the proper order for ExprScalar and T.
+// Then the logic is as follows:
+// - if the operation is natively supported as defined by IsSupported, then the scalar type is not promoted, and T is returned.
+// - otherwise, NumTraits<T>::Literal is returned if T is implicitly convertible to NumTraits<T>::Literal AND that this does not imply a float to integer conversion.
+// - In all other cases, the promoted type is not defined, and the respective operation is thus invalid and not available (SFINAE).
+template<typename ExprScalar,typename T,
+ bool IsSupported,
+ bool ConvertibleToLiteral = internal::is_convertible<T,typename NumTraits<ExprScalar>::Literal>::value,
+ bool IsSafe = NumTraits<T>::IsInteger || !NumTraits<typename NumTraits<ExprScalar>::Literal>::IsInteger>
+struct promote_scalar_arg
+{
+};
+
+template<typename S,typename T, bool ConvertibleToLiteral, bool IsSafe>
+struct promote_scalar_arg<S,T,true,ConvertibleToLiteral,IsSafe>
+{
+ typedef T type;
+};
+
+template<typename S,typename T>
+struct promote_scalar_arg<S,T,false,true,true>
+{
+ typedef typename NumTraits<S>::Literal type;
+};
+
//classes inheriting no_assignment_operator don't generate a default operator=.
class no_assignment_operator
{
diff --git a/test/mixingtypes.cpp b/test/mixingtypes.cpp
index fe8c16470..57ef85c32 100644
--- a/test/mixingtypes.cpp
+++ b/test/mixingtypes.cpp
@@ -79,6 +79,11 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType)
VERIFY_MIX_SCALAR(vf * scf , vf.template cast<complex<float> >() * scf);
VERIFY_MIX_SCALAR(scd * vd , scd * vd.template cast<complex<double> >());
+ VERIFY_MIX_SCALAR(vcf * 2 , vcf * complex<float>(2));
+ VERIFY_MIX_SCALAR(vcf * 2.1 , vcf * complex<float>(2.1));
+ VERIFY_MIX_SCALAR(2 * vcf, vcf * complex<float>(2));
+ VERIFY_MIX_SCALAR(2.1 * vcf , vcf * complex<float>(2.1));
+
// check scalar quotients
VERIFY_MIX_SCALAR(vcf / sf , vcf / complex<float>(sf));
VERIFY_MIX_SCALAR(vf / scf , vf.template cast<complex<float> >() / scf);
diff --git a/test/nesting_ops.cpp b/test/nesting_ops.cpp
index 2f5025305..a419b0e44 100644
--- a/test/nesting_ops.cpp
+++ b/test/nesting_ops.cpp
@@ -75,8 +75,8 @@ template <typename MatrixType> void run_nesting_ops_2(const MatrixType& _m)
}
else
{
- VERIFY( verify_eval_type<1>(2*m1, 2*m1) );
- VERIFY( verify_eval_type<2>(2*m1, m1) );
+ VERIFY( verify_eval_type<2>(2*m1, 2*m1) );
+ VERIFY( verify_eval_type<3>(2*m1, m1) );
}
VERIFY( verify_eval_type<2>(m1+m1, m1+m1) );
VERIFY( verify_eval_type<3>(m1+m1, m1) );