diff options
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 98 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 12 |
2 files changed, 99 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index c530b27a7..8e898619d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -48,7 +48,7 @@ class BaseTensorContractionMapper { m_k_strides(k_strides) { } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE void prefetch(int /*i*/) { } + EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row) const { @@ -142,6 +142,13 @@ class BaseTensorContractionMapper { return IndexPair<Index>(linidx[0], linidx[1]); } + Index firstAligned(Index size) const { + return size; + } + Index stride() const { + return 1; + } + protected: const Tensor m_tensor; const nocontract_t m_nocontract_strides; @@ -202,6 +209,18 @@ class TensorContractionSubMapper { return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); } + template <typename PacketT, int AlignmentType> + EIGEN_ALWAYS_INLINE PacketT load(Index i) const { + EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE); + return loadPacket(i); + } + + template <typename Packet> + bool aligned(Index /*i*/) const { + return false; + } + private: const ParentMapper& m_base_mapper; const Index m_vert_offset; @@ -220,6 +239,7 @@ class TensorContractionInputMapper public: typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base; typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper; + typedef SubMapper VectorMapper; TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides, @@ -233,6 +253,10 @@ class TensorContractionInputMapper return SubMapper(*this, i, j); } + EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { + return VectorMapper(*this, i, j); + } + typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::half HalfPacket; @@ -306,6 +330,7 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co public: typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base; typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper; + typedef SubMapper VectorMapper; TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides, @@ -319,6 +344,10 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co return SubMapper(*this, i, j); } + EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { + return VectorMapper(*this, i, j); + } + typedef typename packet_traits<Scalar>::type Packet; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { @@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase if (this->m_lhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalTyped<true, true, true, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer); } else { - static_cast<const Derived*>(this)->template evalTyped<true, true, false, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalTyped<true, false, true, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer); } else { - static_cast<const Derived*>(this)->template evalTyped<true, false, false, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer); } } } else { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalTyped<false, true, true, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer); } else { - static_cast<const Derived*>(this)->template evalTyped<false, true, false, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalTyped<false, false, true, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer); } else { - static_cast<const Derived*>(this)->template evalTyped<false, false, false, Unaligned>(buffer); + static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer); } } } } + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> + void evalGemv(Scalar* buffer) const { + const Index rows = m_i_size; + const Index cols = m_k_size; + + typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar; + typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar; + typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator; + typedef TensorEvaluator<RightArgType, Device> RightEvaluator; + const int lhs_packet_size = internal::packet_traits<LhsScalar>::size; + const int rhs_packet_size = internal::packet_traits<RhsScalar>::size; + typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, + LeftEvaluator, left_nocontract_t, + contract_t, lhs_packet_size, + lhs_inner_dim_contiguous, + false, Unaligned> LhsMapper; + + typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, + RightEvaluator, right_nocontract_t, + contract_t, rhs_packet_size, + rhs_inner_dim_contiguous, + rhs_inner_dim_reordered, Unaligned> RhsMapper; + + LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, + m_left_contracting_strides, m_k_strides); + RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides, + m_right_contracting_strides, m_k_strides); + + const Scalar alpha(1); + const Index resIncr(1); + + // zero out the result buffer (which must be of size at least rows * sizeof(Scalar) + m_device.memset(buffer, 0, rows * sizeof(Scalar)); + + internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run( + rows, cols, lhs, rhs, + buffer, resIncr, alpha); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); m_rightImpl.cleanup(); @@ -707,7 +775,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Base(op, device) { } template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> - EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const { + void evalProduct(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + return; + } + + evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + } + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> + EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index cf1352a31..f0e9bb616 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Base(op, device) {} template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> - void evalTyped(Scalar* buffer) const { + void evalProduct(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + return; + } + + evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + } + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> + void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; |