aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2009-07-27 11:42:54 +0200
committerGravatar Gael Guennebaud <g.gael@free.fr>2009-07-27 11:42:54 +0200
commitf95b77be6216db7b5448d7d728339cae81129bc9 (patch)
tree813c275adde58d57f1ab0bbc59f5df18a013a613
parent6aba84719d09cf19a43a0e8356b010a0f37f2a5d (diff)
trmm is now fully working and available via TriangularView::operator*
-rw-r--r--Eigen/src/Core/TriangularMatrix.h23
-rw-r--r--Eigen/src/Core/products/TriangularMatrixMatrix.h392
-rw-r--r--Eigen/src/Core/products/TriangularMatrixVector.h116
-rw-r--r--test/CMakeLists.txt11
-rw-r--r--test/product_trmm.cpp69
-rw-r--r--test/product_trmv.cpp (renamed from test/product_triangular.cpp)0
-rw-r--r--test/product_trsm.cpp1
7 files changed, 577 insertions, 35 deletions
diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h
index d2f1e6c28..861b738cb 100644
--- a/Eigen/src/Core/TriangularMatrix.h
+++ b/Eigen/src/Core/TriangularMatrix.h
@@ -142,8 +142,10 @@ struct ei_traits<TriangularView<MatrixType, _Mode> > : ei_traits<MatrixType>
};
};
-template<typename Lhs,typename Rhs>
-struct ei_triangular_vector_product_returntype;
+template<int Mode, bool LhsIsTriangular,
+ typename Lhs, bool LhsIsVector,
+ typename Rhs, bool RhsIsVector>
+struct ei_triangular_product_returntype;
template<typename _MatrixType, unsigned int _Mode> class TriangularView
: public TriangularBase<TriangularView<_MatrixType, _Mode> >
@@ -247,11 +249,24 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
return res;
}
+ /** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived>
- ei_triangular_vector_product_returntype<TriangularView,OtherDerived>
+ ei_triangular_product_returntype<Mode,true,MatrixType,false,OtherDerived,OtherDerived::IsVectorAtCompileTime>
operator*(const MatrixBase<OtherDerived>& rhs) const
{
- return ei_triangular_vector_product_returntype<TriangularView,OtherDerived>(*this, rhs.derived(), 1);
+ return ei_triangular_product_returntype
+ <Mode,true,MatrixType,false,OtherDerived,OtherDerived::IsVectorAtCompileTime>
+ (m_matrix, rhs.derived());
+ }
+
+ /** Efficient vector/matrix times triangular matrix product */
+ template<typename OtherDerived> friend
+ ei_triangular_product_returntype<Mode,false,OtherDerived,OtherDerived::IsVectorAtCompileTime,MatrixType,false>
+ operator*(const MatrixBase<OtherDerived>& lhs, const TriangularView& rhs)
+ {
+ return ei_triangular_product_returntype
+ <Mode,false,OtherDerived,OtherDerived::IsVectorAtCompileTime,MatrixType,false>
+ (lhs.derived(),rhs.m_matrix);
}
template<typename OtherDerived>
diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix.h b/Eigen/src/Core/products/TriangularMatrixMatrix.h
new file mode 100644
index 000000000..43a4c3d18
--- /dev/null
+++ b/Eigen/src/Core/products/TriangularMatrixMatrix.h
@@ -0,0 +1,392 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2009 Gael Guennebaud <g.gael@free.fr>
+//
+// Eigen is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 3 of the License, or (at your option) any later version.
+//
+// Alternatively, you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation; either version 2 of
+// the License, or (at your option) any later version.
+//
+// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License and a copy of the GNU General Public License along with
+// Eigen. If not, see <http://www.gnu.org/licenses/>.
+
+#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H
+#define EIGEN_TRIANGULAR_MATRIX_MATRIX_H
+
+// template<typename Scalar, int mr, int StorageOrder, bool Conjugate, int Mode>
+// struct ei_gemm_pack_lhs_triangular
+// {
+// Matrix<Scalar,mr,mr,
+// void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows)
+// {
+// ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
+// ei_const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
+// int count = 0;
+// const int peeled_mc = (rows/mr)*mr;
+// for(int i=0; i<peeled_mc; i+=mr)
+// {
+// for(int k=0; k<depth; k++)
+// for(int w=0; w<mr; w++)
+// blockA[count++] = cj(lhs(i+w, k));
+// }
+// for(int i=peeled_mc; i<rows; i++)
+// {
+// for(int k=0; k<depth; k++)
+// blockA[count++] = cj(lhs(i, k));
+// }
+// }
+// };
+
+/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
+ * the general matrix matrix product.
+ */
+template <typename Scalar,
+ int Mode, bool LhsIsTriangular,
+ int LhsStorageOrder, bool ConjugateLhs,
+ int RhsStorageOrder, bool ConjugateRhs,
+ int ResStorageOrder>
+struct ei_product_triangular_matrix_matrix;
+
+template <typename Scalar,
+ int Mode, bool LhsIsTriangular,
+ int LhsStorageOrder, bool ConjugateLhs,
+ int RhsStorageOrder, bool ConjugateRhs>
+struct ei_product_triangular_matrix_matrix<Scalar,Mode,LhsIsTriangular,
+ LhsStorageOrder,ConjugateLhs,
+ RhsStorageOrder,ConjugateRhs,RowMajor>
+{
+ static EIGEN_STRONG_INLINE void run(
+ int size, int otherSize,
+ const Scalar* lhs, int lhsStride,
+ const Scalar* rhs, int rhsStride,
+ Scalar* res, int resStride,
+ Scalar alpha)
+ {
+ ei_product_triangular_matrix_matrix<Scalar,
+ (Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
+ (!LhsIsTriangular),
+ RhsStorageOrder==RowMajor ? ColMajor : RowMajor,
+ ConjugateRhs,
+ LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
+ ConjugateLhs,
+ ColMajor>
+ ::run(size, otherSize, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
+ }
+};
+
+// implements col-major += alpha * op(triangular) * op(general)
+template <typename Scalar, int Mode,
+ int LhsStorageOrder, bool ConjugateLhs,
+ int RhsStorageOrder, bool ConjugateRhs>
+struct ei_product_triangular_matrix_matrix<Scalar,Mode,true,
+ LhsStorageOrder,ConjugateLhs,
+ RhsStorageOrder,ConjugateRhs,ColMajor>
+{
+
+ static EIGEN_DONT_INLINE void run(
+ int size, int cols,
+ const Scalar* _lhs, int lhsStride,
+ const Scalar* _rhs, int rhsStride,
+ Scalar* res, int resStride,
+ Scalar alpha)
+ {
+ int rows = size;
+
+ ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
+ ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
+
+ if (ConjugateRhs)
+ alpha = ei_conj(alpha);
+
+ typedef ei_product_blocking_traits<Scalar> Blocking;
+ enum {
+ SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
+ IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
+ };
+
+ int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
+ int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction
+
+ Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
+ Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
+
+ Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer;
+ triangularBuffer.setZero();
+ triangularBuffer.diagonal().setOnes();
+
+ ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
+ ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
+
+ for(int k2=IsLowerTriangular ? size : 0;
+ IsLowerTriangular ? k2>0 : k2<size;
+ IsLowerTriangular ? k2-=kc : k2+=kc)
+ {
+ const int actual_kc = std::min(IsLowerTriangular ? k2 : size-k2, kc);
+ int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2;
+
+ pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols);
+
+ // the selected lhs's panel has to be split in three different parts:
+ // 1 - the part which is above the diagonal block => skip it
+ // 2 - the diagonal block => special kernel
+ // 3 - the panel below the diagonal block => GEPP
+ // the block diagonal
+ {
+ // for each small vertical panels of lhs
+ for (int k1=0; k1<actual_kc; k1+=SmallPanelWidth)
+ {
+ int actualPanelWidth = std::min<int>(actual_kc-k1, SmallPanelWidth);
+ int lengthTarget = IsLowerTriangular ? actual_kc-k1-actualPanelWidth : k1;
+ int startBlock = actual_k2+k1;
+ int blockBOffset = k1;
+
+ // => GEBP with the micro triangular block
+ // The trick is to pack this micro block while filling the opposite triangular part with zeros.
+ // To this end we do an extra triangular copy to small temporary buffer
+ for (int k=0;k<actualPanelWidth;++k)
+ {
+ if (!(Mode&UnitDiagBit))
+ triangularBuffer.coeffRef(k,k) = lhs(startBlock+k,startBlock+k);
+ for (int i=IsLowerTriangular ? k+1 : 0; IsLowerTriangular ? i<actualPanelWidth : i<k; ++i)
+ triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k);
+ }
+ pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.stride(), actualPanelWidth, actualPanelWidth);
+
+ gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols,
+ actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
+
+ // GEBP with remaining micro panel
+ if (lengthTarget>0)
+ {
+ int startTarget = IsLowerTriangular ? actual_k2+k1+actualPanelWidth : actual_k2;
+
+ pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
+
+ gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols,
+ actualPanelWidth, actual_kc, 0, blockBOffset*Blocking::PacketSize);
+ }
+ }
+ }
+ // the part below the diagonal => GEPP
+ {
+ int start = IsLowerTriangular ? k2 : 0;
+ int end = IsLowerTriangular ? size : actual_k2;
+ for(int i2=start; i2<end; i2+=mc)
+ {
+ const int actual_mc = std::min(i2+mc,end)-i2;
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder,false>()
+ (blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
+
+ gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols);
+ }
+ }
+ }
+
+ ei_aligned_stack_delete(Scalar, blockA, kc*mc);
+ ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
+ }
+};
+
+// implements col-major += alpha * op(general) * op(triangular)
+template <typename Scalar, int Mode,
+ int LhsStorageOrder, bool ConjugateLhs,
+ int RhsStorageOrder, bool ConjugateRhs>
+struct ei_product_triangular_matrix_matrix<Scalar,Mode,false,
+ LhsStorageOrder,ConjugateLhs,
+ RhsStorageOrder,ConjugateRhs,ColMajor>
+{
+
+ static EIGEN_DONT_INLINE void run(
+ int size, int rows,
+ const Scalar* _lhs, int lhsStride,
+ const Scalar* _rhs, int rhsStride,
+ Scalar* res, int resStride,
+ Scalar alpha)
+ {
+ int cols = size;
+
+ ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
+ ei_const_blas_data_mapper<Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
+
+ if (ConjugateRhs)
+ alpha = ei_conj(alpha);
+
+ typedef ei_product_blocking_traits<Scalar> Blocking;
+ enum {
+ SmallPanelWidth = EIGEN_ENUM_MAX(Blocking::mr,Blocking::nr),
+ IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
+ };
+
+ int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
+ int mc = std::min<int>(Blocking::Max_mc,rows); // cache block size along the M direction
+
+ Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
+ Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
+
+ Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer;
+ triangularBuffer.setZero();
+ triangularBuffer.diagonal().setOnes();
+
+ ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,ConjugateRhs> > gebp_kernel;
+ ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
+ ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
+ ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
+
+ for(int k2=IsLowerTriangular ? 0 : size;
+ IsLowerTriangular ? k2<size : k2>0;
+ IsLowerTriangular ? k2+=kc : k2-=kc)
+ {
+ const int actual_kc = std::min(IsLowerTriangular ? size-k2 : k2, kc);
+ int actual_k2 = IsLowerTriangular ? k2 : k2-actual_kc;
+ int rs = IsLowerTriangular ? actual_k2 : size - k2;
+ Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
+
+ pack_rhs(geb, &rhs(actual_k2,IsLowerTriangular ? 0 : k2), rhsStride, alpha, actual_kc, rs);
+
+ // pack the triangular part of the rhs padding the unrolled blocks with zeros
+ {
+ for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
+ {
+ int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
+ int actual_j2 = actual_k2 + j2;
+ int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
+ int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
+ // general part
+ pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
+ &rhs(actual_k2+panelOffset, actual_j2), rhsStride, alpha,
+ panelLength, actualPanelWidth,
+ actual_kc, panelOffset);
+
+ // append the triangular part via a temporary buffer
+ for (int j=0;j<actualPanelWidth;++j)
+ {
+ if (!(Mode&UnitDiagBit))
+ triangularBuffer.coeffRef(j,j) = rhs(actual_j2+j,actual_j2+j);
+ for (int k=IsLowerTriangular ? j+1 : 0; IsLowerTriangular ? k<actualPanelWidth : k<j; ++k)
+ triangularBuffer.coeffRef(k,j) = rhs(actual_j2+k,actual_j2+j);
+ }
+
+ pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
+ triangularBuffer.data(), triangularBuffer.stride(), alpha,
+ actualPanelWidth, actualPanelWidth,
+ actual_kc, j2);
+ }
+ }
+
+ for (int i2=0; i2<rows; i2+=mc)
+ {
+ const int actual_mc = std::min(mc,rows-i2);
+ pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
+
+ // triangular kernel
+ {
+ for (int j2=0; j2<actual_kc; j2+=SmallPanelWidth)
+ {
+ int actualPanelWidth = std::min<int>(actual_kc-j2, SmallPanelWidth);
+ int panelLength = IsLowerTriangular ? actual_kc-j2 : j2+actualPanelWidth;
+ int blockOffset = IsLowerTriangular ? j2 : 0;
+
+ gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride,
+ blockA, blockB+j2*actual_kc*Blocking::PacketSize,
+ actual_mc, panelLength, actualPanelWidth,
+ actual_kc, actual_kc, // strides
+ blockOffset, blockOffset*Blocking::PacketSize);// offsets
+ }
+ }
+ gebp_kernel(res+i2+(IsLowerTriangular ? 0 : k2)*resStride, resStride,
+ blockA, geb, actual_mc, actual_kc, rs);
+ }
+ }
+
+ ei_aligned_stack_delete(Scalar, blockA, kc*mc);
+ ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize);
+ }
+};
+
+/***************************************************************************
+* Wrapper to ei_product_triangular_matrix_matrix
+***************************************************************************/
+
+template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
+struct ei_triangular_product_returntype<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
+ : public ReturnByValue<ei_triangular_product_returntype<Mode,LhsIsTriangular,Lhs,false,Rhs,false>,
+ Matrix<typename ei_traits<Rhs>::Scalar,
+ Lhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
+{
+ ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs)
+ : m_lhs(lhs), m_rhs(rhs)
+ {}
+
+ typedef typename Lhs::Scalar Scalar;
+
+ typedef typename Lhs::Nested LhsNested;
+ typedef typename ei_cleantype<LhsNested>::type _LhsNested;
+ typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
+
+ typedef typename Rhs::Nested RhsNested;
+ typedef typename ei_cleantype<RhsNested>::type _RhsNested;
+ typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
+
+// enum {
+// LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit),
+// LhsIsTriangular = (LhsMode&SelfAdjointBit)==SelfAdjointBit,
+// RhsUpLo = RhsMode&(UpperTriangularBit|LowerTriangularBit),
+// RhsIsSelfAdjoint = (RhsMode&SelfAdjointBit)==SelfAdjointBit
+// };
+
+ template<typename Dest> inline void _addTo(Dest& dst) const
+ { evalTo(dst,1); }
+ template<typename Dest> inline void _subTo(Dest& dst) const
+ { evalTo(dst,-1); }
+
+ template<typename Dest> void evalTo(Dest& dst) const
+ {
+ dst.resize(m_lhs.rows(), m_rhs.cols());
+ dst.setZero();
+ evalTo(dst,1);
+ }
+
+ template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
+ {
+ const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
+ const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
+ * RhsBlasTraits::extractScalarFactor(m_rhs);
+
+ ei_product_triangular_matrix_matrix<Scalar,
+ Mode, LhsIsTriangular,
+ (ei_traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
+ (ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
+ (ei_traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
+ ::run(
+ lhs.rows(), LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes
+ &lhs.coeff(0,0), lhs.stride(), // lhs info
+ &rhs.coeff(0,0), rhs.stride(), // rhs info
+ &dst.coeffRef(0,0), dst.stride(), // result info
+ actualAlpha // alpha
+ );
+ }
+
+ const LhsNested m_lhs;
+ const RhsNested m_rhs;
+};
+
+#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
index 533aad170..0fbbb50d2 100644
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
@@ -113,49 +113,113 @@ struct ei_product_triangular_vector_selector<Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs
}
};
-template<typename Lhs,typename Rhs>
-struct ei_triangular_vector_product_returntype
- : public ReturnByValue<ei_triangular_vector_product_returntype<Lhs,Rhs>,
+// template<typename Lhs,typename Rhs>
+// struct ei_triangular_vector_product_returntype
+// : public ReturnByValue<ei_triangular_vector_product_returntype<Lhs,Rhs>,
+// Matrix<typename ei_traits<Rhs>::Scalar,
+// Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
+// {
+// typedef typename Lhs::Scalar Scalar;
+// typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
+// ei_triangular_vector_product_returntype(const Lhs& lhs, const Rhs& rhs, Scalar alpha)
+// : m_lhs(lhs), m_rhs(rhs), m_alpha(alpha)
+// {}
+//
+// template<typename Dest> void evalTo(Dest& dst) const
+// {
+// typedef typename Lhs::MatrixType MatrixType;
+//
+// typedef ei_blas_traits<MatrixType> LhsBlasTraits;
+// typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+// typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
+// const ActualLhsType actualLhs = LhsBlasTraits::extract(m_lhs._expression());
+//
+// typedef ei_blas_traits<Rhs> RhsBlasTraits;
+// typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+// typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
+// const ActualRhsType actualRhs = RhsBlasTraits::extract(m_rhs);
+//
+// Scalar actualAlpha = m_alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
+// * RhsBlasTraits::extractScalarFactor(m_rhs);
+//
+// dst.resize(m_rhs.rows(), m_rhs.cols());
+// dst.setZero();
+// ei_product_triangular_vector_selector
+// <_ActualLhsType,_ActualRhsType,Dest,
+// ei_traits<Lhs>::Mode,
+// LhsBlasTraits::NeedToConjugate,
+// RhsBlasTraits::NeedToConjugate,
+// ei_traits<Lhs>::Flags&RowMajorBit>
+// ::run(actualLhs,actualRhs,dst,actualAlpha);
+// }
+//
+// const Lhs m_lhs;
+// const typename Rhs::Nested m_rhs;
+// const Scalar m_alpha;
+// };
+
+
+/***************************************************************************
+* Wrapper to ei_product_triangular_vector
+***************************************************************************/
+
+template<int Mode, /*bool LhsIsTriangular, */typename Lhs, typename Rhs>
+struct ei_triangular_product_returntype<Mode,true,Lhs,false,Rhs,true>
+ : public ReturnByValue<ei_triangular_product_returntype<Mode,true,Lhs,false,Rhs,true>,
Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{
typedef typename Lhs::Scalar Scalar;
- typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
- ei_triangular_vector_product_returntype(const Lhs& lhs, const Rhs& rhs, Scalar alpha)
- : m_lhs(lhs), m_rhs(rhs), m_alpha(alpha)
+
+ typedef typename Lhs::Nested LhsNested;
+ typedef typename ei_cleantype<LhsNested>::type _LhsNested;
+ typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
+ typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
+ typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
+
+ typedef typename Rhs::Nested RhsNested;
+ typedef typename ei_cleantype<RhsNested>::type _RhsNested;
+ typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
+ typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
+ typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
+
+ ei_triangular_product_returntype(const Lhs& lhs, const Rhs& rhs)
+ : m_lhs(lhs), m_rhs(rhs)
{}
+ template<typename Dest> inline void _addTo(Dest& dst) const
+ { evalTo(dst,1); }
+ template<typename Dest> inline void _subTo(Dest& dst) const
+ { evalTo(dst,-1); }
+
template<typename Dest> void evalTo(Dest& dst) const
{
- typedef typename Lhs::MatrixType MatrixType;
-
- typedef ei_blas_traits<MatrixType> LhsBlasTraits;
- typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
- typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
- const ActualLhsType actualLhs = LhsBlasTraits::extract(m_lhs._expression());
-
- typedef ei_blas_traits<Rhs> RhsBlasTraits;
- typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
- typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
- const ActualRhsType actualRhs = RhsBlasTraits::extract(m_rhs);
-
- Scalar actualAlpha = m_alpha * LhsBlasTraits::extractScalarFactor(m_lhs._expression())
- * RhsBlasTraits::extractScalarFactor(m_rhs);
-
+ dst.resize(m_lhs.rows(), m_rhs.cols());
+ dst.setZero();
+ evalTo(dst,1);
+ }
+
+ template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
+ {
+ const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
+ const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
+
+ Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
+ * RhsBlasTraits::extractScalarFactor(m_rhs);
+
dst.resize(m_rhs.rows(), m_rhs.cols());
dst.setZero();
ei_product_triangular_vector_selector
<_ActualLhsType,_ActualRhsType,Dest,
- ei_traits<Lhs>::Mode,
+ Mode,
LhsBlasTraits::NeedToConjugate,
RhsBlasTraits::NeedToConjugate,
ei_traits<Lhs>::Flags&RowMajorBit>
- ::run(actualLhs,actualRhs,dst,actualAlpha);
+ ::run(lhs,rhs,dst,actualAlpha);
}
- const Lhs m_lhs;
- const typename Rhs::Nested m_rhs;
- const Scalar m_alpha;
+ const LhsNested m_lhs;
+ const RhsNested m_rhs;
};
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 462032453..99224ff60 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -98,10 +98,6 @@ ei_add_test(redux)
ei_add_test(product_small)
ei_add_test(product_large ${EI_OFLAG})
ei_add_test(product_extra ${EI_OFLAG})
-ei_add_test(product_selfadjoint ${EI_OFLAG})
-ei_add_test(product_symm ${EI_OFLAG})
-ei_add_test(product_syrk ${EI_OFLAG})
-ei_add_test(product_trsm ${EI_OFLAG})
ei_add_test(diagonalmatrices)
ei_add_test(adjoint)
ei_add_test(submatrices)
@@ -113,7 +109,12 @@ ei_add_test(array)
ei_add_test(array_replicate)
ei_add_test(array_reverse)
ei_add_test(triangular)
-ei_add_test(product_triangular)
+ei_add_test(product_selfadjoint ${EI_OFLAG})
+ei_add_test(product_symm ${EI_OFLAG})
+ei_add_test(product_syrk ${EI_OFLAG})
+ei_add_test(product_trmv ${EI_OFLAG})
+ei_add_test(product_trmm ${EI_OFLAG})
+ei_add_test(product_trsm ${EI_OFLAG})
ei_add_test(bandmatrix)
ei_add_test(cholesky " " "${GSL_LIBRARIES}")
ei_add_test(lu ${EI_OFLAG})
diff --git a/test/product_trmm.cpp b/test/product_trmm.cpp
new file mode 100644
index 000000000..47ffb4af3
--- /dev/null
+++ b/test/product_trmm.cpp
@@ -0,0 +1,69 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@gmail.com>
+//
+// Eigen is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 3 of the License, or (at your option) any later version.
+//
+// Alternatively, you can redistribute it and/or
+// modify it under the terms of the GNU General Public License as
+// published by the Free Software Foundation; either version 2 of
+// the License, or (at your option) any later version.
+//
+// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
+// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License and a copy of the GNU General Public License along with
+// Eigen. If not, see <http://www.gnu.org/licenses/>.
+
+#include "main.h"
+
+template<typename Scalar> void trmm(int size,int othersize)
+{
+ typedef typename NumTraits<Scalar>::Real RealScalar;
+
+ Matrix<Scalar,Dynamic,Dynamic,ColMajor> tri(size,size), upTri(size,size), loTri(size,size);
+ Matrix<Scalar,Dynamic,Dynamic,ColMajor> ge1(size,othersize), ge2(10,size), ge3;
+ Matrix<Scalar,Dynamic,Dynamic,RowMajor> rge3;
+
+ Scalar s1 = ei_random<Scalar>(),
+ s2 = ei_random<Scalar>();
+
+ tri.setRandom();
+ loTri = tri.template triangularView<LowerTriangular>();
+ upTri = tri.template triangularView<UpperTriangular>();
+ ge1.setRandom();
+ ge2.setRandom();
+
+ VERIFY_IS_APPROX( ge3 = tri.template triangularView<LowerTriangular>() * ge1, loTri * ge1);
+ VERIFY_IS_APPROX(rge3 = tri.template triangularView<LowerTriangular>() * ge1, loTri * ge1);
+ VERIFY_IS_APPROX( ge3 = tri.template triangularView<UpperTriangular>() * ge1, upTri * ge1);
+ VERIFY_IS_APPROX(rge3 = tri.template triangularView<UpperTriangular>() * ge1, upTri * ge1);
+ VERIFY_IS_APPROX( ge3 = (s1*tri.adjoint()).template triangularView<UpperTriangular>() * (s2*ge1), s1*loTri.adjoint() * (s2*ge1));
+ VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge1, loTri.adjoint() * ge1);
+ VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge1, upTri.adjoint() * ge1);
+ VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge1, upTri.adjoint() * ge1);
+ VERIFY_IS_APPROX( ge3 = tri.template triangularView<LowerTriangular>() * ge2.adjoint(), loTri * ge2.adjoint());
+ VERIFY_IS_APPROX(rge3 = tri.template triangularView<LowerTriangular>() * ge2.adjoint(), loTri * ge2.adjoint());
+ VERIFY_IS_APPROX( ge3 = tri.template triangularView<UpperTriangular>() * ge2.adjoint(), upTri * ge2.adjoint());
+ VERIFY_IS_APPROX(rge3 = tri.template triangularView<UpperTriangular>() * ge2.adjoint(), upTri * ge2.adjoint());
+ VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
+ VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<UpperTriangular>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
+ VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
+ VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<LowerTriangular>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
+}
+
+void test_product_trmm()
+{
+ for(int i = 0; i < g_repeat ; i++)
+ {
+ trmm<float>(ei_random<int>(1,320),ei_random<int>(1,320));
+ trmm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320));
+ }
+}
diff --git a/test/product_triangular.cpp b/test/product_trmv.cpp
index 876fb4388..876fb4388 100644
--- a/test/product_triangular.cpp
+++ b/test/product_trmv.cpp
diff --git a/test/product_trsm.cpp b/test/product_trsm.cpp
index 80df5861e..bda158048 100644
--- a/test/product_trsm.cpp
+++ b/test/product_trsm.cpp
@@ -85,6 +85,7 @@ template<typename Scalar> void trsm(int size,int cols)
solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
}
+
void test_product_trsm()
{
for(int i = 0; i < g_repeat ; i++)