diff options
author | Gael Guennebaud <g.gael@free.fr> | 2019-01-30 16:48:01 +0100 |
---|---|---|
committer | Gael Guennebaud <g.gael@free.fr> | 2019-01-30 16:48:01 +0100 |
commit | d586686924c2783f56bd514c9365afeecc3e84f6 (patch) | |
tree | a1da7f53e8e307ab7bec0c4818561f571c3f813f /unsupported | |
parent | eb4c6bb22dfb6c83deaea65fe5d1bbb4cdc1f8bb (diff) |
Work around the lack of support for arbitrary packet types in Tensor by manually loading half/quarter packets in the tensor contraction mapper.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 64dfcd297..142492603 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -241,8 +241,10 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } template <typename PacketT,int AlignmentType> - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type + load(Index i, Index j) const + { // whole method makes column major assumption // don't need to add offsets for now (because operator handles that) @@ -284,6 +286,29 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, } template <typename PacketT,int AlignmentType> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type + load(Index i, Index j) const + { + const Index requested_packet_size = internal::unpacket_traits<PacketT>::size; + EIGEN_ALIGN_MAX Scalar data[requested_packet_size]; + + const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1); + const Index first = indexPair.first; + const Index lastIdx = indexPair.second; + + data[0] = this->m_tensor.coeff(first); + for (Index k = 1; k < requested_packet_size - 1; k += 2) { + const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1); + data[k] = this->m_tensor.coeff(internal_pair.first); + data[k + 1] = this->m_tensor.coeff(internal_pair.second); + } + data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx); + + return pload<PacketT>(data); + } + + template <typename PacketT,int 
AlignmentType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const { return this->load<PacketT,AlignmentType>(i,j); |