From 6d472d83754e0f16db1deb69218e10c2b21268b1 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 19 Jan 2016 17:22:05 -0800 Subject: Moved the contraction mapping code to its own file to make the code more manageable. --- .../CXX11/src/Tensor/TensorContractionMapper.h | 377 +++++++++++++++++++++ 1 file changed, 377 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h new file mode 100644 index 000000000..b25b34d61 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -0,0 +1,377 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H + +namespace Eigen { + +namespace internal { + +enum { + Rhs = 0, + Lhs = 1, +}; + +/* + * Implementation of the Eigen blas_data_mapper class for tensors. + */ +template +class SimpleTensorContractionMapper { + public: + EIGEN_DEVICE_FUNC + SimpleTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + m_tensor(tensor), + m_nocontract_strides(nocontract_strides), + m_ij_strides(ij_strides), + m_contract_strides(contract_strides), + m_k_strides(k_strides) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row) const { + // column major assumption + return operator()(row, 0); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const { + return m_tensor.coeff(computeIndex(row, col)); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const { + const bool left = (side == Lhs); + Index nocontract_val = left ? row : col; + Index linidx = 0; + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx = nocontract_val / m_ij_strides[i]; + linidx += idx * m_nocontract_strides[i]; + nocontract_val -= idx * m_ij_strides[i]; + } + if (array_size::value > array_size::value) { + if (side == Lhs && inner_dim_contiguous) { + eigen_assert(m_nocontract_strides[0] == 1); + linidx += nocontract_val; + } else { + linidx += nocontract_val * m_nocontract_strides[0]; + } + } + + Index contract_val = left ? col : row; + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx = contract_val / m_k_strides[i]; + linidx += idx * m_contract_strides[i]; + contract_val -= idx * m_k_strides[i]; + } + + if(array_size::value > 0) { + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx += contract_val; + } else { + linidx += contract_val * m_contract_strides[0]; + } + } + + return linidx; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE IndexPair computeIndexPair(Index row, Index col, const Index distance) const { + const bool left = (side == Lhs); + Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; + Index linidx[2] = {0, 0}; + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx0 = nocontract_val[0] / m_ij_strides[i]; + const Index idx1 = nocontract_val[1] / m_ij_strides[i]; + linidx[0] += idx0 * m_nocontract_strides[i]; + linidx[1] += idx1 * m_nocontract_strides[i]; + nocontract_val[0] -= idx0 * m_ij_strides[i]; + nocontract_val[1] -= idx1 * m_ij_strides[i]; + } + if (array_size::value > array_size::value) { + if (side == Lhs && inner_dim_contiguous) { + eigen_assert(m_nocontract_strides[0] == 1); + linidx[0] += nocontract_val[0]; + linidx[1] += nocontract_val[1]; + } else { + linidx[0] += nocontract_val[0] * m_nocontract_strides[0]; + linidx[1] += nocontract_val[1] * m_nocontract_strides[0]; + } + } + + Index contract_val[2] = {left ? col : row, left ? col : row + distance}; + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx0 = contract_val[0] / m_k_strides[i]; + const Index idx1 = contract_val[1] / m_k_strides[i]; + linidx[0] += idx0 * m_contract_strides[i]; + linidx[1] += idx1 * m_contract_strides[i]; + contract_val[0] -= idx0 * m_k_strides[i]; + contract_val[1] -= idx1 * m_k_strides[i]; + } + + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx[0] += contract_val[0]; + linidx[1] += contract_val[1]; + } else { + linidx[0] += contract_val[0] * m_contract_strides[0]; + linidx[1] += contract_val[1] * m_contract_strides[0]; + } + return IndexPair(linidx[0], linidx[1]); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const { + // Only claim alignment when we can compute the actual stride (ie when we're + // dealing with the lhs with inner_dim_contiguous. This is because the + // matrix-vector product relies on the stride when dealing with aligned inputs. + return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { + return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; + } + + protected: + const Tensor m_tensor; + const nocontract_t m_nocontract_strides; + const nocontract_t m_ij_strides; + const contract_t m_contract_strides; + const contract_t m_k_strides; +}; + + +template +class BaseTensorContractionMapper : public SimpleTensorContractionMapper +{ + public: + typedef SimpleTensorContractionMapper ParentMapper; + + EIGEN_DEVICE_FUNC + BaseTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + template + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { + // whole method makes column major assumption + + // don't need to add offsets for now (because operator handles that) + // current code assumes packet size must be a multiple of 2 + EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); + + if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) { + const Index index = this->computeIndex(i, j); + eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1); + return this->m_tensor.template packet(index); + } + + const IndexPair indexPair = this->computeIndexPair(i, j, packet_size - 1); + const Index first = indexPair.first; + const Index last = indexPair.second; + + // We can always do optimized packet reads from left hand side right now, because + // the vertical matrix dimension on the left hand side is never contracting. + // On the right hand side we need to check if the contracting dimensions may have + // been shuffled first. + if (Tensor::PacketAccess && + (side == Lhs || internal::array_size::value <= 1 || !inner_dim_reordered) && + (last - first) == (packet_size - 1)) { + + return this->m_tensor.template packet(first); + } + + EIGEN_ALIGN_MAX Scalar data[packet_size]; + + data[0] = this->m_tensor.coeff(first); + for (Index k = 1; k < packet_size - 1; k += 2) { + const IndexPair internal_pair = this->computeIndexPair(i + k, j, 1); + data[k] = this->m_tensor.coeff(internal_pair.first); + data[k + 1] = this->m_tensor.coeff(internal_pair.second); + } + data[packet_size - 1] = this->m_tensor.coeff(last); + + return pload(data); + } + + template + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { + // whole method makes column major assumption + + // don't need to add offsets for now (because operator handles that) + const Index half_packet_size = unpacket_traits::size; + if (half_packet_size == packet_size) { + return loadPacket(i, j); + } + EIGEN_ALIGN_MAX Scalar data[half_packet_size]; + for (Index k = 0; k < half_packet_size; k++) { + data[k] = operator()(i + k, j); + } + return pload(data); + } +}; + + +template +class BaseTensorContractionMapper : public SimpleTensorContractionMapper +{ + public: + typedef SimpleTensorContractionMapper ParentMapper; + + EIGEN_DEVICE_FUNC + BaseTensorContractionMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) : + ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + typedef typename packet_traits::type Packet; + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { + EIGEN_ALIGN_MAX Scalar data[1]; + data[0] = this->m_tensor.coeff(this->computeIndex(i, j)); + return pload::type>(data); + } + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const { + return loadPacket(i, j); + } +}; + +template +class TensorContractionInputMapper; + +template +class TensorContractionSubMapper { + public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + typedef TensorContractionInputMapper ParentMapper; + typedef TensorContractionSubMapper Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper(i + m_vert_offset, m_horiz_offset); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const { + return m_base_mapper(i + m_vert_offset, j + m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.template loadPacket(i + m_vert_offset, m_horiz_offset); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { + return m_base_mapper.template loadPacket(i + m_vert_offset, j + m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { + return m_base_mapper.template loadHalfPacket(i + m_vert_offset, m_horiz_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const { + m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { + EIGEN_STATIC_ASSERT((internal::is_same::value), YOU_MADE_A_PROGRAMMING_MISTAKE); + const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned; + return m_base_mapper.template loadPacket(i + m_vert_offset, m_horiz_offset); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const { + return false; + } + + private: + const ParentMapper& m_base_mapper; + const Index m_vert_offset; + const Index m_horiz_offset; +}; + + +template +class TensorContractionInputMapper + : public BaseTensorContractionMapper { + + public: + typedef BaseTensorContractionMapper Base; + typedef TensorContractionSubMapper SubMapper; + typedef SubMapper VectorMapper; + + EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, + const nocontract_t& nocontract_strides, + const nocontract_t& ij_strides, + const contract_t& contract_strides, + const contract_t& k_strides) + : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { + return VectorMapper(*this, i, j); + } +}; + + + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H -- cgit v1.2.3