aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/Product.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2010-07-11 15:48:30 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2010-07-11 15:48:30 +0200
commitff96c94043d575e4d0dd477c1ed2487e33f79627 (patch)
tree5e9736916779fdacd431c2591a3ec1f77333e505 /Eigen/src/Core/Product.h
parent4161b8be6772f2b7338458c9932d7417797966bb (diff)
mixing types in product step 2:
* pload* and pset1 are now templated on the packet type * gemv routines are now embeded into a structure with a consistent API with respect to gemm * some configurations of vector * matrix and matrix * matrix works fine, some need more work...
Diffstat (limited to 'Eigen/src/Core/Product.h')
-rw-r--r--Eigen/src/Core/Product.h63
1 files changed, 37 insertions, 26 deletions
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h
index 139132c6b..edd79bd9a 100644
--- a/Eigen/src/Core/Product.h
+++ b/Eigen/src/Core/Product.h
@@ -282,10 +282,13 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
public:
EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
+ typedef typename Lhs::Scalar LhsScalar;
+ typedef typename Rhs::Scalar RhsScalar;
+
GeneralProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
{
- EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret),
- YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
+// EIGEN_STATIC_ASSERT((ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret),
+// YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
}
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
@@ -295,7 +298,8 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
{
ei_assert(m_lhs.rows() == dst.rows() && m_rhs.cols() == dst.cols());
ei_gemv_selector<Side,(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
- bool(ei_blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha);
+ bool(ei_blas_traits<MatrixType>::HasUsableDirectAccess)
+ /*&& ei_is_same_type<typename Lhs::Scalar, typename Rhs::Scalar>::ret*/>::run(*this, dst, alpha);
}
};
@@ -319,43 +323,48 @@ template<> struct ei_gemv_selector<OnTheRight,ColMajor,true>
template<typename ProductType, typename Dest>
static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
{
- typedef typename ProductType::Scalar Scalar;
+ typedef typename ProductType::Index Index;
+ typedef typename ProductType::LhsScalar LhsScalar;
+ typedef typename ProductType::RhsScalar RhsScalar;
+ typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::ActualLhsType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType;
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
- typedef Map<Matrix<Scalar,Dynamic,1>, Aligned> MappedDest;
+ typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
- * RhsBlasTraits::extractScalarFactor(prod.rhs());
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
+ * RhsBlasTraits::extractScalarFactor(prod.rhs());
enum {
// FIXME find a way to allow an inner stride on the result if ei_packet_traits<Scalar>::size==1
EvalToDest = Dest::InnerStrideAtCompileTime==1
};
- Scalar* EIGEN_RESTRICT actualDest;
+ ResScalar* actualDest;
if (EvalToDest)
actualDest = &dest.coeffRef(0);
else
{
- actualDest = ei_aligned_stack_new(Scalar,dest.size());
+ actualDest = ei_aligned_stack_new(ResScalar,dest.size());
MappedDest(actualDest, dest.size()) = dest;
}
- ei_cache_friendly_product_colmajor_times_vector
- <LhsBlasTraits::NeedToConjugate,RhsBlasTraits::NeedToConjugate>(
- dest.size(),
- &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(),
- actualRhs, actualDest, actualAlpha);
+ ei_general_matrix_vector_product
+ <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+ actualLhs.rows(), actualLhs.cols(),
+ &actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(),
+ actualRhs, actualRhs.innerStride(),
+ actualDest, 1,
+ actualAlpha);
if (!EvalToDest)
{
dest = MappedDest(actualDest, dest.size());
- ei_aligned_stack_delete(Scalar, actualDest, dest.size());
+ ei_aligned_stack_delete(ResScalar, actualDest, dest.size());
}
}
};
@@ -365,7 +374,9 @@ template<> struct ei_gemv_selector<OnTheRight,RowMajor,true>
template<typename ProductType, typename Dest>
static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
{
- typedef typename ProductType::Scalar Scalar;
+ typedef typename ProductType::LhsScalar LhsScalar;
+ typedef typename ProductType::RhsScalar RhsScalar;
+ typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::Index Index;
typedef typename ProductType::ActualLhsType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType;
@@ -376,34 +387,34 @@ template<> struct ei_gemv_selector<OnTheRight,RowMajor,true>
ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
- Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
- * RhsBlasTraits::extractScalarFactor(prod.rhs());
+ ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
+ * RhsBlasTraits::extractScalarFactor(prod.rhs());
enum {
// FIXME I think here we really have to check for ei_packet_traits<Scalar>::size==1
// because in this case it is fine to have an inner stride
- DirectlyUseRhs = ((ei_packet_traits<Scalar>::size==1) || (_ActualRhsType::Flags&ActualPacketAccessBit))
+ DirectlyUseRhs = ((ei_packet_traits<RhsScalar>::size==1) || (_ActualRhsType::Flags&ActualPacketAccessBit))
&& (!(_ActualRhsType::Flags & RowMajorBit))
};
- Scalar* EIGEN_RESTRICT rhs_data;
+ RhsScalar* rhs_data;
if (DirectlyUseRhs)
- rhs_data = reinterpret_cast<Scalar* EIGEN_RESTRICT>(&actualRhs.const_cast_derived().coeffRef(0));
+ rhs_data = &actualRhs.const_cast_derived().coeffRef(0);
else
{
- rhs_data = ei_aligned_stack_new(Scalar, actualRhs.size());
- Map<typename _ActualRhsType::PlainObject>(reinterpret_cast<Scalar*>(rhs_data), actualRhs.size()) = actualRhs;
+ rhs_data = ei_aligned_stack_new(RhsScalar, actualRhs.size());
+ Map<typename _ActualRhsType::PlainObject>(rhs_data, actualRhs.size()) = actualRhs;
}
- ei_cache_friendly_product_rowmajor_times_vector
- <LhsBlasTraits::NeedToConjugate,RhsBlasTraits::NeedToConjugate, Scalar, Index>(
+ ei_general_matrix_vector_product
+ <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
actualLhs.rows(), actualLhs.cols(),
&actualLhs.const_cast_derived().coeffRef(0,0), actualLhs.outerStride(),
rhs_data, 1,
&dest.coeffRef(0,0), dest.innerStride(),
actualAlpha);
- if (!DirectlyUseRhs) ei_aligned_stack_delete(Scalar, rhs_data, prod.rhs().size());
+ if (!DirectlyUseRhs) ei_aligned_stack_delete(RhsScalar, rhs_data, prod.rhs().size());
}
};