aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-22 16:36:30 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-22 16:36:30 -0800
commit3aeeca32af00b1921b4424d7be2e03bbaeaa05b4 (patch)
tree3642b1288127f1910380d115d176eaecd8fcda44 /unsupported/Eigen/CXX11
parent4beb447e27baaa19081e835bd6aba76e9b02cc67 (diff)
Leverage the new blocking code in the tensor contraction code.
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h4
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h5
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h8
3 files changed, 8 insertions, 9 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 624e814e2..e6a008ba7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -582,10 +582,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
OutputMapper output(buffer, m);
- typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;
-
// Sizes of the blocks to load in cache. See the Goto paper for details.
- BlockingType blocking(m, n, k, 1, true);
+ internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
const Index kc = blocking.kc();
const Index mc = numext::mini(m, blocking.mc());
const Index nc = numext::mini(n, blocking.nc());
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
index 9b6d18090..63c8ae126 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h
@@ -426,15 +426,16 @@ class TensorContractionSubMapper {
};
-template<typename Scalar, typename Index, int side,
+template<typename Scalar_, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
- : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
+ : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
public:
+ typedef Scalar_ Scalar;
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 576bea295..51a3b9490 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -176,10 +176,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// compute block sizes (which depend on number of threads)
const Index num_threads = this->m_device.numThreads();
- Index mc = m;
- Index nc = n;
- Index kc = k;
- internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads);
+ internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
+ Index mc = blocking.mc();
+ Index nc = blocking.nc();
+ Index kc = blocking.kc();
eigen_assert(mc <= m);
eigen_assert(nc <= n);
eigen_assert(kc <= k);