aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2016-12-03 21:25:04 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2016-12-03 21:25:04 +0100
commit4465d20403921f9acd705ba3955057d729fd04b7 (patch)
tree261c87146119bf470bf49eb769f7e547d27c5303 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
parent6a5fe860985311cc275c4bb7000e0d261822c756 (diff)
Add missing generic load methods.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h26
1 files changed, 23 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index 9b2cb3ff6..a2d7c7414 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -235,9 +235,9 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
typedef typename Tensor::PacketReturnType Packet;
typedef typename unpacket_traits<Packet>::half HalfPacket;
- template <int AlignmentType>
+ template <typename PacketT,int AlignmentType>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
+ EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
// whole method makes column major assumption
// don't need to add offsets for now (because operator handles that)
@@ -275,7 +275,13 @@ class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar,
}
data[packet_size - 1] = this->m_tensor.coeff(last);
- return pload<Packet>(data);
+ return pload<PacketT>(data);
+ }
+
+ template <int AlignmentType>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
+ return this->load<Packet,AlignmentType>(i,j);
}
template <int AlignmentType>
@@ -322,6 +328,12 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
return pload<typename Tensor::PacketReturnType>(data);
}
+ template <typename PacketT,int> EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
+ EIGEN_ALIGN_MAX Scalar data[1];
+ data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
+ return pload<PacketT>(data);
+ }
template <int> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
return loadPacket(i, j);
@@ -385,6 +397,14 @@ class TensorContractionSubMapper {
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
}
+ template <typename PacketT, int AlignmentType>
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
+ if (UseDirectOffsets) {
+ return m_base_mapper.template load<PacketT,AlignmentType>(i, j);
+ }
+ return m_base_mapper.template loadPacket<PacketT,AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
+ }
+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
if (UseDirectOffsets) {
return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);