From 4e2f6de1a8fd9a659dc40ed54fedff9abdef3b1f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 1 Apr 2019 11:47:31 -0700 Subject: Add support for custom packed Lhs/Rhs blocks in tensor contractions --- .../CXX11/src/Tensor/TensorContractionMapper.h | 45 ++++++++++++++++++---- 1 file changed, 38 insertions(+), 7 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 142492603..1be823fd1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -24,12 +24,17 @@ enum { */ /// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which /// is scalar * for CoeffLoader. -template class MakePointer_ = MakePointer> struct CoeffLoader; -template class MakePointer_ = MakePointer> class BaseTensorContractionMapper; +template class MakePointer_ = MakePointer> +struct CoeffLoader; -template class MakePointer_> struct CoeffLoader { +template class MakePointer_ = MakePointer> +class BaseTensorContractionMapper; + +template class MakePointer_> +struct CoeffLoader { enum { DirectOffsets = false }; @@ -40,6 +45,12 @@ template class MakePointer eigen_assert(false && "unsupported"); } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_::Type + data() const { + eigen_assert(false && "unsupported"); + return NULL; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -48,12 +59,12 @@ template class MakePointer return m_tensor.template packet(index); } - private: const Tensor m_tensor; }; -template class MakePointer_> struct CoeffLoader { +template class MakePointer_> +struct CoeffLoader { enum { DirectOffsets = true }; @@ -64,6 +75,11 @@ template class MakePointer_> struct CoeffLoad m_data += offset; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_::Type + data() const { + return m_data; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -214,6 +230,17 @@ class SimpleTensorContractionMapper { return ((side == Lhs) && inner_dim_contiguous && array_size::value > 0) ? m_contract_strides[0] : 1; } + const CoeffLoader& tensor() const { + return m_tensor; + } + + const nocontract_t& nocontract_strides() const { + return m_nocontract_strides; + } + const nocontract_t& ij_strides() const { return m_ij_strides; } + const contract_t& contract_strides() const { return m_contract_strides; } + const contract_t& k_strides() const { return m_k_strides; } + protected: CoeffLoader m_tensor; const nocontract_t m_nocontract_strides; @@ -445,6 +472,10 @@ class TensorContractionSubMapper { return false; } + const ParentMapper& base_mapper() const { return m_base_mapper; } + Index vert_offset() const { return m_vert_offset; } + Index horiz_offset() const { return m_horiz_offset; } + private: ParentMapper m_base_mapper; const Index m_vert_offset; -- cgit v1.2.3