diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-04-01 11:47:31 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-04-01 11:47:31 -0700 |
commit | 4e2f6de1a8fd9a659dc40ed54fedff9abdef3b1f (patch) | |
tree | e510ad53ee053b68327462c0e6d944db5dc362d0 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | |
parent | 45e65fbb7791e453f88f959111cff45e0fb7dd6f (diff) |
Add support for custom packed Lhs/Rhs blocks in tensor contractions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 45 |
1 files changed, 38 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 142492603..1be823fd1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -24,12 +24,17 @@ enum { */ /// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which /// is scalar * for CoeffLoader. -template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> struct CoeffLoader; -template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t, - int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, - template <class> class MakePointer_ = MakePointer> class BaseTensorContractionMapper; +template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> +struct CoeffLoader; -template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> struct CoeffLoader { +template <typename Scalar, typename Index, int side, typename Tensor, + typename nocontract_t, typename contract_t, int packet_size, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + template <class> class MakePointer_ = MakePointer> +class BaseTensorContractionMapper; + +template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> +struct CoeffLoader { enum { DirectOffsets = false }; @@ -40,6 +45,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer eigen_assert(false && "unsupported"); } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type + data() const { + eigen_assert(false && "unsupported"); + return NULL; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); } template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -48,12 +59,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer return m_tensor.template packet<LoadMode>(index); } - private: const Tensor m_tensor; }; -template <typename Tensor, template <class> class MakePointer_> struct CoeffLoader<Tensor, true, MakePointer_> { +template <typename Tensor, template <class> class MakePointer_> +struct CoeffLoader<Tensor, true, MakePointer_> { enum { DirectOffsets = true }; @@ -64,6 +75,11 @@ template <typename Tensor, template <class> class MakePointer_> struct CoeffLoad m_data += offset; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type + data() const { + return m_data; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); } template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -214,6 +230,17 @@ class SimpleTensorContractionMapper { return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1; } + const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const { + return m_tensor; + } + + const nocontract_t& nocontract_strides() const { + return m_nocontract_strides; + } + const nocontract_t& ij_strides() const { return m_ij_strides; } + const contract_t& contract_strides() const { return m_contract_strides; } + const contract_t& k_strides() const { return m_k_strides; } + protected: CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor; const nocontract_t m_nocontract_strides; @@ -445,6 +472,10 @@ class TensorContractionSubMapper { return false; } + const ParentMapper& base_mapper() const { return m_base_mapper; } + Index vert_offset() const { return m_vert_offset; } + Index horiz_offset() const { return m_horiz_offset; } + private: ParentMapper m_base_mapper; const Index m_vert_offset; |