diff options
Diffstat (limited to 'unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h')
-rw-r--r-- | unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h | 56 |
1 files changed, 35 insertions, 21 deletions
diff --git a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h index b8f2cba17..446fcac16 100644 --- a/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h +++ b/unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h @@ -48,8 +48,8 @@ class KroneckerProductBase : public ReturnByValue<Derived> */ Scalar coeff(Index row, Index col) const { - return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * - m_B.coeff(row % m_B.rows(), col % m_B.cols()); + return m_A.coeff(typename Lhs::Index(row / m_B.rows()), typename Lhs::Index(col / m_B.cols())) * + m_B.coeff(typename Rhs::Index(row % m_B.rows()), typename Rhs::Index(col % m_B.cols())); } /*! @@ -59,7 +59,7 @@ class KroneckerProductBase : public ReturnByValue<Derived> Scalar coeff(Index i) const { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); - return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size()); + return m_A.coeff(typename Lhs::Index(i / m_A.size())) * m_B.coeff(typename Rhs::Index(i % m_A.size())); } protected: @@ -148,38 +148,53 @@ template<typename Lhs, typename Rhs> template<typename Dest> void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const { - typedef typename Base::Index Index; - const Index Br = m_B.rows(), - Bc = m_B.cols(); - dst.resize(this->rows(), this->cols()); + typedef typename Dest::Index DestIndex; + const typename Rhs::Index Br = m_B.rows(), + Bc = m_B.cols(); + eigen_assert(this->rows() <= NumTraits<DestIndex>::highest()); + eigen_assert(this->cols() <= NumTraits<DestIndex>::highest()); + dst.resize(DestIndex(this->rows()), DestIndex(this->cols())); dst.resizeNonZeros(0); + // 1 - evaluate the operands if needed: + typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1; + typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned; + const Lhs1 lhs1(m_A); + typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1; + typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned; + const Rhs1 rhs1(m_B); + + // 2 - construct respective iterators + typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator; + typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator; + // compute number of non-zeros per innervectors of dst { VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols()); - for (Index kA=0; kA < m_A.outerSize(); ++kA) - for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) + for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA) + for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++; VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols()); - for (Index kB=0; kB < m_B.outerSize(); ++kB) - for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) + for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB) + for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++; Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose(); dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size())); } - for (Index kA=0; kA < m_A.outerSize(); ++kA) + for (typename Lhs::Index kA=0; kA < m_A.outerSize(); ++kA) { - for (Index kB=0; kB < m_B.outerSize(); ++kB) + for (typename Rhs::Index kB=0; kB < m_B.outerSize(); ++kB) { - for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA) + for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) { - for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB) + for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) { - const Index i = itA.row() * Br + itB.row(), - j = itA.col() * Bc + itB.col(); + const DestIndex + i = DestIndex(itA.row() * Br + itB.row()), + j = DestIndex(itA.col() * Bc + itB.col()); dst.insert(i,j) = itA.value() * itB.value(); } } @@ -201,8 +216,7 @@ struct traits<KroneckerProduct<_Lhs,_Rhs> > Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, - MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, - CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost + MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret }; typedef Matrix<Scalar,Rows,Cols> ReturnType; @@ -215,7 +229,7 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> > typedef typename remove_all<_Lhs>::type Lhs; typedef typename remove_all<_Rhs>::type Rhs; typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; - typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind; + typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind, scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret StorageKind; typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index; enum { @@ -235,7 +249,7 @@ struct traits<KroneckerProductSparse<_Lhs,_Rhs> > CoeffReadCost = Dynamic }; - typedef SparseMatrix<Scalar> ReturnType; + typedef SparseMatrix<Scalar, 0, Index> ReturnType; }; } // end namespace internal |