aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-04-01 11:47:31 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-04-01 11:47:31 -0700
commit4e2f6de1a8fd9a659dc40ed54fedff9abdef3b1f (patch)
treee510ad53ee053b68327462c0e6d944db5dc362d0 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
parent45e65fbb7791e453f88f959111cff45e0fb7dd6f (diff)
Add support for custom packed Lhs/Rhs blocks in tensor contractions
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h45
1 files changed, 38 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index 142492603..1be823fd1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -24,12 +24,17 @@ enum {
*/
/// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which
/// is scalar * for CoeffLoader.
-template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> struct CoeffLoader;
-template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
- int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
- template <class> class MakePointer_ = MakePointer> class BaseTensorContractionMapper;
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
+struct CoeffLoader;
-template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> struct CoeffLoader {
+template <typename Scalar, typename Index, int side, typename Tensor,
+ typename nocontract_t, typename contract_t, int packet_size,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ template <class> class MakePointer_ = MakePointer>
+class BaseTensorContractionMapper;
+
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
+struct CoeffLoader {
enum {
DirectOffsets = false
};
@@ -40,6 +45,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
eigen_assert(false && "unsupported");
}
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+ data() const {
+ eigen_assert(false && "unsupported");
+ return NULL;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -48,12 +59,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
return m_tensor.template packet<LoadMode>(index);
}
-
private:
const Tensor m_tensor;
};
-template <typename Tensor, template <class> class MakePointer_> struct CoeffLoader<Tensor, true, MakePointer_> {
+template <typename Tensor, template <class> class MakePointer_>
+struct CoeffLoader<Tensor, true, MakePointer_> {
enum {
DirectOffsets = true
};
@@ -64,6 +75,11 @@ template <typename Tensor, template <class> class MakePointer_> struct CoeffLoad
m_data += offset;
}
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+ data() const {
+ return m_data;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -214,6 +230,17 @@ class SimpleTensorContractionMapper {
return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
}
+ const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
+ return m_tensor;
+ }
+
+ const nocontract_t& nocontract_strides() const {
+ return m_nocontract_strides;
+ }
+ const nocontract_t& ij_strides() const { return m_ij_strides; }
+ const contract_t& contract_strides() const { return m_contract_strides; }
+ const contract_t& k_strides() const { return m_k_strides; }
+
protected:
CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
const nocontract_t m_nocontract_strides;
@@ -445,6 +472,10 @@ class TensorContractionSubMapper {
return false;
}
+ const ParentMapper& base_mapper() const { return m_base_mapper; }
+ Index vert_offset() const { return m_vert_offset; }
+ Index horiz_offset() const { return m_horiz_offset; }
+
private:
ParentMapper m_base_mapper;
const Index m_vert_offset;