about summary refs log tree commit diff homepage
path: root/unsupported/Eigen/CXX11/src
diff options
context:
space:
mode:
author Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-11-03 08:51:33 -0800
committer Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-11-03 08:51:33 -0800
commit b1789c112b5cf8d478a03786c6c1243320aefd47 (patch)
tree 16d91868395df93f058bda1acb469b93070e7e1a /unsupported/Eigen/CXX11/src
parent 2dde63499c4ef836a0d9dfd443494d863ad62b16 (diff)
Improved handling of 1d tensors
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h 98
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h 12
2 files changed, 99 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index c530b27a7..8e898619d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -48,7 +48,7 @@ class BaseTensorContractionMapper {
m_k_strides(k_strides) { }
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE void prefetch(int /*i*/) { }
+ EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
@@ -142,6 +142,13 @@ class BaseTensorContractionMapper {
return IndexPair<Index>(linidx[0], linidx[1]);
}
+ Index firstAligned(Index size) const {
+ return size;
+ }
+ Index stride() const {
+ return 1;
+ }
+
protected:
const Tensor m_tensor;
const nocontract_t m_nocontract_strides;
@@ -202,6 +209,18 @@ class TensorContractionSubMapper {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
+ template <typename PacketT, int AlignmentType>
+ EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
+ EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ return loadPacket(i);
+ }
+
+ template <typename Packet>
+ bool aligned(Index /*i*/) const {
+ return false;
+ }
+
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
@@ -220,6 +239,7 @@ class TensorContractionInputMapper
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
@@ -233,6 +253,10 @@ class TensorContractionInputMapper
return SubMapper(*this, i, j);
}
+ EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
+ return VectorMapper(*this, i, j);
+ }
+
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
@@ -306,6 +330,7 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
+ typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
@@ -319,6 +344,10 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
return SubMapper(*this, i, j);
}
+ EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
+ return VectorMapper(*this, i, j);
+ }
+
typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
@@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase
if (this->m_lhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
- static_cast<const Derived*>(this)->template evalTyped<true, true, true, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
}
else {
- static_cast<const Derived*>(this)->template evalTyped<true, true, false, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
- static_cast<const Derived*>(this)->template evalTyped<true, false, true, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
}
else {
- static_cast<const Derived*>(this)->template evalTyped<true, false, false, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
}
}
}
else {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
- static_cast<const Derived*>(this)->template evalTyped<false, true, true, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
}
else {
- static_cast<const Derived*>(this)->template evalTyped<false, true, false, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
}
}
else {
if (this->m_rhs_inner_dim_reordered) {
- static_cast<const Derived*>(this)->template evalTyped<false, false, true, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
}
else {
- static_cast<const Derived*>(this)->template evalTyped<false, false, false, Unaligned>(buffer);
+ static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
}
}
}
}
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ void evalGemv(Scalar* buffer) const {
+ const Index rows = m_i_size;
+ const Index cols = m_k_size;
+
+ typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
+ typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<RightArgType, Device> RightEvaluator;
+ const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
+ const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
+ typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
+ LeftEvaluator, left_nocontract_t,
+ contract_t, lhs_packet_size,
+ lhs_inner_dim_contiguous,
+ false, Unaligned> LhsMapper;
+
+ typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
+ RightEvaluator, right_nocontract_t,
+ contract_t, rhs_packet_size,
+ rhs_inner_dim_contiguous,
+ rhs_inner_dim_reordered, Unaligned> RhsMapper;
+
+ LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
+ m_left_contracting_strides, m_k_strides);
+ RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
+ m_right_contracting_strides, m_k_strides);
+
+ const Scalar alpha(1);
+ const Index resIncr(1);
+
+ // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
+ m_device.memset(buffer, 0, rows * sizeof(Scalar));
+
+ internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
+ rows, cols, lhs, rhs,
+ buffer, resIncr, alpha);
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_leftImpl.cleanup();
m_rightImpl.cleanup();
@@ -707,7 +775,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) { }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const {
+ void evalProduct(Scalar* buffer) const {
+ if (this->m_j_size == 1) {
+ this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ return;
+ }
+
+ evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index cf1352a31..f0e9bb616 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) {}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- void evalTyped(Scalar* buffer) const {
+ void evalProduct(Scalar* buffer) const {
+ if (this->m_j_size == 1) {
+ this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ return;
+ }
+
+ evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;