diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-12 11:32:27 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-12 11:32:27 -0800 |
commit | d920d57f38e07739403a6c1e224c74fec5a36e6f (patch) | |
tree | ca940402a68713459e9944942f7b4492e1dd99d9 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | |
parent | bd7d901da9bd824ee2a7a94d3b0c8668d77f5ff2 (diff) |
Improved the performance of the contraction of a 2d tensor with a 1d tensor by a factor of 3 or more. This helps speedup LSTM neural networks.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 48 |
1 files changed, 26 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index eda93a1de..63d0c6f68 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -32,7 +32,7 @@ enum { template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t, - int packet_size, bool inner_dim_contiguous> + int packet_size, bool inner_dim_contiguous, int Alignment> class SimpleTensorContractionMapper { public: EIGEN_DEVICE_FUNC @@ -144,11 +144,11 @@ class SimpleTensorContractionMapper { return IndexPair<Index>(linidx[0], linidx[1]); } - Index firstAligned(Index size) const { - return size; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const { + return (Alignment == Aligned) ? 0 : size; } - Index stride() const { - return 1; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { + return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; } protected: @@ -165,10 +165,10 @@ template<typename Scalar, typename Index, int side, typename nocontract_t, typename contract_t, int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> - class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> +class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> { public: - typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper; + typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper; EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, @@ -181,6 +181,7 @@ template<typename Scalar, typename Index, int side, typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::half HalfPacket; + template <int AlignmentType = Alignment> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { // whole method makes column major assumption @@ -192,7 +193,7 @@ template<typename Scalar, typename Index, int side, if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) { const Index index = this->computeIndex(i, j); eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1); - return this->m_tensor.template packet<Alignment>(index); + return this->m_tensor.template packet<AlignmentType>(index); } const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1); @@ -207,7 +208,7 @@ template<typename Scalar, typename Index, int side, (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) && (last - first) == (packet_size - 1)) { - return this->m_tensor.template packet<Alignment>(first); + return this->m_tensor.template packet<AlignmentType>(first); } EIGEN_ALIGN_MAX Scalar data[packet_size]; @@ -223,6 +224,7 @@ template<typename Scalar, typename Index, int side, return pload<Packet>(data); } + template <int AlignmentType = Alignment> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { // whole method makes column major assumption @@ -230,7 +232,7 @@ template<typename Scalar, typename Index, int side, // don't need to add offsets for now (because operator handles that) const Index half_packet_size = unpacket_traits<HalfPacket>::size; if (half_packet_size == packet_size) { - return loadPacket(i, j); + return loadPacket<AlignmentType>(i, j); } EIGEN_ALIGN_MAX Scalar data[half_packet_size]; for (Index k = 0; k < half_packet_size; k++) { @@ -246,10 +248,10 @@ template<typename Scalar, typename Index, int side, typename nocontract_t, typename contract_t, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> -class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> +class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> { public: - typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper; + typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper; EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, @@ -260,13 +262,13 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC + template <int> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { EIGEN_ALIGN_MAX Scalar data[1]; data[0] = this->m_tensor.coeff(this->computeIndex(i, j)); return pload<typename packet_traits<Scalar>::type>(data); } - EIGEN_DEVICE_FUNC + template <int> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const { return loadPacket(i, j); } @@ -304,14 +306,14 @@ class TensorContractionSubMapper { } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { - return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset); + return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset); } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { - return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset); + return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset); } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { - return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset); + return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset); } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const { @@ -325,12 +327,12 @@ class TensorContractionSubMapper { template <typename PacketT, int AlignmentType> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE); - EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE); - return loadPacket(i); + const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned; + return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset); } template <typename Packet> - EIGEN_DEVICE_FUNC bool aligned(Index) const { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const { return false; } @@ -741,17 +743,19 @@ struct TensorContractionEvaluatorBase typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size; const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size; + const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned; + const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned; typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t, lhs_packet_size, lhs_inner_dim_contiguous, - false, Unaligned> LhsMapper; + false, lhs_alignment> LhsMapper; typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t, contract_t, rhs_packet_size, rhs_inner_dim_contiguous, - rhs_inner_dim_reordered, Unaligned> RhsMapper; + rhs_inner_dim_reordered, rhs_alignment> RhsMapper; LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides); |