diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-19 17:22:05 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-19 17:22:05 -0800 |
commit | 6d472d83754e0f16db1deb69218e10c2b21268b1 (patch) | |
tree | 38b222275e3297f9fa910e0fcb26d59f54d27ff3 /unsupported | |
parent | b3b722905f3df26a34cdda4f2cee74aa62403040 (diff) |
Moved the contraction mapping code to its own file to make the code more manageable.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/Tensor | 1 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 357 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 377 |
3 files changed, 378 insertions, 357 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index 7481a9ddb..1c5734383 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -88,6 +88,7 @@ typedef unsigned __int64 uint64_t; #include "src/Tensor/TensorReductionCuda.h" #include "src/Tensor/TensorArgMax.h" #include "src/Tensor/TensorConcatenation.h" +#include "src/Tensor/TensorContractionMapper.h" #include "src/Tensor/TensorContraction.h" #include "src/Tensor/TensorContractionThreadPool.h" #include "src/Tensor/TensorContractionCuda.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 72a378dfd..506696ae9 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -21,363 +21,6 @@ namespace Eigen { */ namespace internal { -enum { - Rhs = 0, - Lhs = 1, -}; - -/* - * Implementation of the Eigen blas_data_mapper class for tensors. - */ -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - int packet_size, bool inner_dim_contiguous, int Alignment> -class SimpleTensorContractionMapper { - public: - EIGEN_DEVICE_FUNC - SimpleTensorContractionMapper(const Tensor& tensor, - const nocontract_t& nocontract_strides, - const nocontract_t& ij_strides, - const contract_t& contract_strides, - const contract_t& k_strides) : - m_tensor(tensor), - m_nocontract_strides(nocontract_strides), - m_ij_strides(ij_strides), - m_contract_strides(contract_strides), - m_k_strides(k_strides) { } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar operator()(Index row) const { - // column major assumption - return operator()(row, 0); - } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const { - return m_tensor.coeff(computeIndex(row, col)); - } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const { - const bool left = (side == Lhs); - Index nocontract_val = left ? row : col; - Index linidx = 0; - for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { - const Index idx = nocontract_val / m_ij_strides[i]; - linidx += idx * m_nocontract_strides[i]; - nocontract_val -= idx * m_ij_strides[i]; - } - if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { - if (side == Lhs && inner_dim_contiguous) { - eigen_assert(m_nocontract_strides[0] == 1); - linidx += nocontract_val; - } else { - linidx += nocontract_val * m_nocontract_strides[0]; - } - } - - Index contract_val = left ? col : row; - for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { - const Index idx = contract_val / m_k_strides[i]; - linidx += idx * m_contract_strides[i]; - contract_val -= idx * m_k_strides[i]; - } - - if(array_size<contract_t>::value > 0) { - if (side == Rhs && inner_dim_contiguous) { - eigen_assert(m_contract_strides[0] == 1); - linidx += contract_val; - } else { - linidx += contract_val * m_contract_strides[0]; - } - } - - return linidx; - } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const { - const bool left = (side == Lhs); - Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; - Index linidx[2] = {0, 0}; - for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { - const Index idx0 = nocontract_val[0] / m_ij_strides[i]; - const Index idx1 = nocontract_val[1] / m_ij_strides[i]; - linidx[0] += idx0 * m_nocontract_strides[i]; - linidx[1] += idx1 * m_nocontract_strides[i]; - nocontract_val[0] -= idx0 * m_ij_strides[i]; - nocontract_val[1] -= idx1 * m_ij_strides[i]; - } - if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { - if (side == Lhs && inner_dim_contiguous) { - eigen_assert(m_nocontract_strides[0] == 1); - linidx[0] += nocontract_val[0]; - linidx[1] += nocontract_val[1]; - } else { - linidx[0] += nocontract_val[0] * m_nocontract_strides[0]; - linidx[1] += nocontract_val[1] * m_nocontract_strides[0]; - } - } - - Index contract_val[2] = {left ? col : row, left ? col : row + distance}; - for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { - const Index idx0 = contract_val[0] / m_k_strides[i]; - const Index idx1 = contract_val[1] / m_k_strides[i]; - linidx[0] += idx0 * m_contract_strides[i]; - linidx[1] += idx1 * m_contract_strides[i]; - contract_val[0] -= idx0 * m_k_strides[i]; - contract_val[1] -= idx1 * m_k_strides[i]; - } - - if (side == Rhs && inner_dim_contiguous) { - eigen_assert(m_contract_strides[0] == 1); - linidx[0] += contract_val[0]; - linidx[1] += contract_val[1]; - } else { - linidx[0] += contract_val[0] * m_contract_strides[0]; - linidx[1] += contract_val[1] * m_contract_strides[0]; - } - return IndexPair<Index>(linidx[0], linidx[1]); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const { - // Only claim alignment when we can compute the actual stride (ie when we're - // dealing with the lhs with inner_dim_contiguous. This is because the - // matrix-vector product relies on the stride when dealing with aligned inputs. - return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { - return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; - } - - protected: - const Tensor m_tensor; - const nocontract_t m_nocontract_strides; - const nocontract_t m_ij_strides; - const contract_t m_contract_strides; - const contract_t m_k_strides; -}; - - -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - int packet_size, bool inner_dim_contiguous, - bool inner_dim_reordered, int Alignment> -class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> -{ - public: - typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper; - - EIGEN_DEVICE_FUNC - BaseTensorContractionMapper(const Tensor& tensor, - const nocontract_t& nocontract_strides, - const nocontract_t& ij_strides, - const contract_t& contract_strides, - const contract_t& k_strides) : - ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } - - typedef typename packet_traits<Scalar>::type Packet; - typedef typename packet_traits<Scalar>::half HalfPacket; - - template <int AlignmentType = Alignment> - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { - // whole method makes column major assumption - - // don't need to add offsets for now (because operator handles that) - // current code assumes packet size must be a multiple of 2 - EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); - - if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) { - const Index index = this->computeIndex(i, j); - eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1); - return this->m_tensor.template packet<AlignmentType>(index); - } - - const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1); - const Index first = indexPair.first; - const Index last = indexPair.second; - - // We can always do optimized packet reads from left hand side right now, because - // the vertical matrix dimension on the left hand side is never contracting. - // On the right hand side we need to check if the contracting dimensions may have - // been shuffled first. - if (Tensor::PacketAccess && - (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) && - (last - first) == (packet_size - 1)) { - - return this->m_tensor.template packet<AlignmentType>(first); - } - - EIGEN_ALIGN_MAX Scalar data[packet_size]; - - data[0] = this->m_tensor.coeff(first); - for (Index k = 1; k < packet_size - 1; k += 2) { - const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1); - data[k] = this->m_tensor.coeff(internal_pair.first); - data[k + 1] = this->m_tensor.coeff(internal_pair.second); - } - data[packet_size - 1] = this->m_tensor.coeff(last); - - return pload<Packet>(data); - } - - template <int AlignmentType = Alignment> - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { - // whole method makes column major assumption - - // don't need to add offsets for now (because operator handles that) - const Index half_packet_size = unpacket_traits<HalfPacket>::size; - if (half_packet_size == packet_size) { - return loadPacket<AlignmentType>(i, j); - } - EIGEN_ALIGN_MAX Scalar data[half_packet_size]; - for (Index k = 0; k < half_packet_size; k++) { - data[k] = operator()(i + k, j); - } - return pload<HalfPacket>(data); - } -}; - - -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - bool inner_dim_contiguous, - bool inner_dim_reordered, int Alignment> -class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> -{ - public: - typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper; - - EIGEN_DEVICE_FUNC - BaseTensorContractionMapper(const Tensor& tensor, - const nocontract_t& nocontract_strides, - const nocontract_t& ij_strides, - const contract_t& contract_strides, - const contract_t& k_strides) : - ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } - - typedef typename packet_traits<Scalar>::type Packet; - template <int> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { - EIGEN_ALIGN_MAX Scalar data[1]; - data[0] = this->m_tensor.coeff(this->computeIndex(i, j)); - return pload<typename packet_traits<Scalar>::type>(data); - } - template <int> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const { - return loadPacket(i, j); - } -}; - -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - int packet_size, - bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> -class TensorContractionInputMapper; - -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - int packet_size, - bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> -class TensorContractionSubMapper { - public: - typedef typename packet_traits<Scalar>::type Packet; - typedef typename packet_traits<Scalar>::half HalfPacket; - - typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper; - typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self; - typedef Self LinearMapper; - - EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) - : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { - return m_base_mapper(i + m_vert_offset, m_horiz_offset); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const { - return m_base_mapper(i + m_vert_offset, j + m_horiz_offset); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { - return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset); - } - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { - return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { - return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const { - m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { - return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); - } - - template <typename PacketT, int AlignmentType> - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { - EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE); - const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned; - return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset); - } - - template <typename Packet> - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const { - return false; - } - - private: - const ParentMapper& m_base_mapper; - const Index m_vert_offset; - const Index m_horiz_offset; -}; - - -template<typename Scalar, typename Index, int side, - typename Tensor, - typename nocontract_t, typename contract_t, - int packet_size, - bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> -class TensorContractionInputMapper - : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> { - - public: - typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base; - typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper; - typedef SubMapper VectorMapper; - - EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, - const nocontract_t& nocontract_strides, - const nocontract_t& ij_strides, - const contract_t& contract_strides, - const contract_t& k_strides) - : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } - - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { - return SubMapper(*this, i, j); - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { - return VectorMapper(*this, i, j); - } -}; - - - template<typename Dimensions, typename LhsXprType, typename RhsXprType> struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h new file mode 100644 index 000000000..b25b34d61 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -0,0 +1,377 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H + +namespace Eigen { + +namespace internal { + +enum { + Rhs = 0, + Lhs = 1, +}; + +/* + * Implementation of the Eigen blas_data_mapper class for tensors. + */ +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + int packet_size, bool inner_dim_contiguous, int Alignment> +class SimpleTensorContractionMapper { + public: + EIGEN_DEVICE_FUNC + SimpleTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + m_tensor(tensor), + m_nocontract_strides(nocontract_strides), + m_ij_strides(ij_strides), + m_contract_strides(contract_strides), + m_k_strides(k_strides) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row) const { + // column major assumption + return operator()(row, 0); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const { + return m_tensor.coeff(computeIndex(row, col)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const { + const bool left = (side == Lhs); + Index nocontract_val = left ? row : col; + Index linidx = 0; + for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { + const Index idx = nocontract_val / m_ij_strides[i]; + linidx += idx * m_nocontract_strides[i]; + nocontract_val -= idx * m_ij_strides[i]; + } + if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { + if (side == Lhs && inner_dim_contiguous) { + eigen_assert(m_nocontract_strides[0] == 1); + linidx += nocontract_val; + } else { + linidx += nocontract_val * m_nocontract_strides[0]; + } + } + + Index contract_val = left ? col : row; + for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { + const Index idx = contract_val / m_k_strides[i]; + linidx += idx * m_contract_strides[i]; + contract_val -= idx * m_k_strides[i]; + } + + if(array_size<contract_t>::value > 0) { + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx += contract_val; + } else { + linidx += contract_val * m_contract_strides[0]; + } + } + + return linidx; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const { + const bool left = (side == Lhs); + Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; + Index linidx[2] = {0, 0}; + for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { + const Index idx0 = nocontract_val[0] / m_ij_strides[i]; + const Index idx1 = nocontract_val[1] / m_ij_strides[i]; + linidx[0] += idx0 * m_nocontract_strides[i]; + linidx[1] += idx1 * m_nocontract_strides[i]; + nocontract_val[0] -= idx0 * m_ij_strides[i]; + nocontract_val[1] -= idx1 * m_ij_strides[i]; + } + if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { + if (side == Lhs && inner_dim_contiguous) { + eigen_assert(m_nocontract_strides[0] == 1); + linidx[0] += nocontract_val[0]; + linidx[1] += nocontract_val[1]; + } else { + linidx[0] += nocontract_val[0] * m_nocontract_strides[0]; + linidx[1] += nocontract_val[1] * m_nocontract_strides[0]; + } + } + + Index contract_val[2] = {left ? col : row, left ? col : row + distance}; + for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { + const Index idx0 = contract_val[0] / m_k_strides[i]; + const Index idx1 = contract_val[1] / m_k_strides[i]; + linidx[0] += idx0 * m_contract_strides[i]; + linidx[1] += idx1 * m_contract_strides[i]; + contract_val[0] -= idx0 * m_k_strides[i]; + contract_val[1] -= idx1 * m_k_strides[i]; + } + + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx[0] += contract_val[0]; + linidx[1] += contract_val[1]; + } else { + linidx[0] += contract_val[0] * m_contract_strides[0]; + linidx[1] += contract_val[1] * m_contract_strides[0]; + } + return IndexPair<Index>(linidx[0], linidx[1]); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const { + // Only claim alignment when we can compute the actual stride (ie when we're + // dealing with the lhs with inner_dim_contiguous. This is because the + // matrix-vector product relies on the stride when dealing with aligned inputs. + return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { + return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; + } + + protected: + const Tensor m_tensor; + const nocontract_t m_nocontract_strides; + const nocontract_t m_ij_strides; + const contract_t m_contract_strides; + const contract_t m_k_strides; +}; + + +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + int packet_size, bool inner_dim_contiguous, + bool inner_dim_reordered, int Alignment> +class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> +{ + public: + typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper; + + EIGEN_DEVICE_FUNC + BaseTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + typedef typename packet_traits<Scalar>::type Packet; + typedef typename packet_traits<Scalar>::half HalfPacket; + + template <int AlignmentType = Alignment> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { + // whole method makes column major assumption + + // don't need to add offsets for now (because operator handles that) + // current code assumes packet size must be a multiple of 2 + EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + + if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) { + const Index index = this->computeIndex(i, j); + eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1); + return this->m_tensor.template packet<AlignmentType>(index); + } + + const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1); + const Index first = indexPair.first; + const Index last = indexPair.second; + + // We can always do optimized packet reads from left hand side right now, because + // the vertical matrix dimension on the left hand side is never contracting. + // On the right hand side we need to check if the contracting dimensions may have + // been shuffled first. + if (Tensor::PacketAccess && + (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) && + (last - first) == (packet_size - 1)) { + + return this->m_tensor.template packet<AlignmentType>(first); + } + + EIGEN_ALIGN_MAX Scalar data[packet_size]; + + data[0] = this->m_tensor.coeff(first); + for (Index k = 1; k < packet_size - 1; k += 2) { + const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1); + data[k] = this->m_tensor.coeff(internal_pair.first); + data[k + 1] = this->m_tensor.coeff(internal_pair.second); + } + data[packet_size - 1] = this->m_tensor.coeff(last); + + return pload<Packet>(data); + } + + template <int AlignmentType = Alignment> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { + // whole method makes column major assumption + + // don't need to add offsets for now (because operator handles that) + const Index half_packet_size = unpacket_traits<HalfPacket>::size; + if (half_packet_size == packet_size) { + return loadPacket<AlignmentType>(i, j); + } + EIGEN_ALIGN_MAX Scalar data[half_packet_size]; + for (Index k = 0; k < half_packet_size; k++) { + data[k] = operator()(i + k, j); + } + return pload<HalfPacket>(data); + } +}; + + +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + bool inner_dim_contiguous, + bool inner_dim_reordered, int Alignment> +class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> +{ + public: + typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper; + + EIGEN_DEVICE_FUNC + BaseTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + typedef typename packet_traits<Scalar>::type Packet; + template <int> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { + EIGEN_ALIGN_MAX Scalar data[1]; + data[0] = this->m_tensor.coeff(this->computeIndex(i, j)); + return pload<typename packet_traits<Scalar>::type>(data); + } + template <int> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const { + return loadPacket(i, j); + } +}; + +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + int packet_size, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> +class TensorContractionInputMapper; + +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + int packet_size, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> +class TensorContractionSubMapper { + public: + typedef typename packet_traits<Scalar>::type Packet; + typedef typename packet_traits<Scalar>::half HalfPacket; + + typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper; + typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper(i + m_vert_offset, m_horiz_offset); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const { + return m_base_mapper(i + m_vert_offset, j + m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { + return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { + return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const { + m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); + } + + template <typename PacketT, int AlignmentType> + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { + EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE); + const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned; + return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset); + } + + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const { + return false; + } + + private: + const ParentMapper& m_base_mapper; + const Index m_vert_offset; + const Index m_horiz_offset; +}; + + +template<typename Scalar, typename Index, int side, + typename Tensor, + typename nocontract_t, typename contract_t, + int packet_size, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> +class TensorContractionInputMapper + : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> { + + public: + typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base; + typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper; + typedef SubMapper VectorMapper; + + EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) + : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { + return VectorMapper(*this, i, j); + } +}; + + + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H |