author	Benoit Steiner <benoit.steiner.goog@gmail.com>	2016-12-21 16:42:56 -0800
committer	Benoit Steiner <benoit.steiner.goog@gmail.com>	2016-12-21 16:42:56 -0800
commit	4236aebe103b0fa54f3b9e7e3c0c12094fa6e200 (patch)
tree	eb3c18dd42f192442834877eb34a4c65a8cc7b20 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent	519d63d350222ddbed5db1883a8fb2c7aab4b4e9 (diff)
Simplified the contraction code
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--	unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 45
1 file changed, 18 insertions(+), 27 deletions(-)
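
The change below hoists the beta factor out of the innermost loop (it depends only on ki) and drops the ap/bp/cp temporaries in favor of reusing actual_a and the already-hoisted kernel strides. For orientation, here is a minimal, self-contained sketch of the beta convention the diff relies on; this is not Eigen's code, and block_gemm with its column-major layout is an illustrative assumption:

    #include <algorithm>

    // Sketch: in a k-blocked GEMM, beta selects between overwriting C
    // (first k-block, beta = 0) and accumulating into it (beta = 1).
    void block_gemm(const float* A, const float* B, float* C,
                    int m, int n, int k, int kc) {
      for (int ki = 0; ki < k; ki += kc) {
        const float beta = (ki == 0) ? 0.0f : 1.0f;  // hoisted once per k-block
        const int actual_kc = std::min(ki + kc, k) - ki;
        for (int j = 0; j < n; ++j) {
          for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int kk = 0; kk < actual_kc; ++kk) {
              // A is m x k with leading dimension m; B is k x n with leading dimension k.
              acc += A[(ki + kk) * m + i] * B[j * k + (ki + kk)];
            }
            // beta == 0 overwrites C without reading it; beta == 1 accumulates.
            C[j * m + i] = (beta == 0.0f) ? acc : C[j * m + i] + acc;
          }
        }
      }
    }
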
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index c446ba1af..442c14fac 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -720,24 +720,20 @@ protected:
const LhsScalar* leftData = m_leftImpl.data();
const RhsScalar* rightData = m_rightImpl.data();
- libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
- libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
- libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
+ const libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
+ const libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
+ const libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
- libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
+ const libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
// Use bigger stride to avoid hitting same cache line too often.
// This consistently gives +~0.5 Gflops.
- libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
+ const libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
kc % 32 == 0 ? kc + 16 : kc
);
// Kernel for the general case (not edges)
internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel;
- const LhsScalar *ap;
- const RhsScalar *bp;
- const Scalar *cp;
-
LhsScalar* blockA = NULL;
RhsScalar* panelB = NULL;
@@ -748,8 +744,8 @@ protected:
panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar)));
}
- Index kernel_stride_A = copyA ? stride_blockA : stride_A;
- Index kernel_stride_B = copyB ? stride_panelB : stride_B;
+ const Index kernel_stride_A = copyA ? stride_blockA : stride_A;
+ const Index kernel_stride_B = copyB ? stride_panelB : stride_B;
kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch());
// Outer blocking
@@ -763,6 +759,7 @@ protected:
// Inner blocking
for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) {
const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki;
+ const float beta = ki == 0 ? 0 : 1;
if (copyB) {
if (transposeB) {
@@ -775,8 +772,8 @@ protected:
for (Index mi = mi_outer; mi < mini(mi_outer+mc_outer, m); mi += mc) {
const Index actual_mc = mini(mi_outer+mc_outer, mini(mi+mc, m)) - mi;
- const LhsScalar * a = transposeA ? leftData + mi*stride_A + ki :
- leftData + ki*stride_A + mi;
+ const LhsScalar* a = transposeA ? leftData + mi*stride_A + ki :
+ leftData + ki*stride_A + mi;
if (copyA) {
if (transposeA) {
@@ -785,30 +782,24 @@ protected:
internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A);
}
}
+ const LhsScalar* actual_a = copyA ? blockA : a;
for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) {
const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni;
- const RhsScalar * b = rightData + ni*stride_B + ki;
- Scalar * c = buffer + ni*stride_C + mi;
- cp = c + nc*stride_C;
-
- const LhsScalar * actual_a = copyA ? blockA : a;
- const Index actual_lda = copyA ? stride_blockA : stride_A;
- ap = copyA ? blockA : a;
-
- const RhsScalar * actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
- const Index actual_ldb = copyB ? stride_panelB : stride_B;
- bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
+ const RhsScalar* b = rightData + ni*stride_B + ki;
+ Scalar* c = buffer + ni*stride_C + mi;
+ const Scalar* cp = c + nc*stride_C;
- float beta = ki == 0 ? 0 : 1;
+ const RhsScalar* actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
+ const RhsScalar* bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) {
// Most used, cached kernel.
- kernel(actual_a, actual_b, c, ap, bp, cp);
+ kernel(actual_a, actual_b, c, actual_a, bp, cp);
} else {
// Edges - use libxsmm kernel cache.
- internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, actual_lda, actual_ldb, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, ap, bp, cp);
+ internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, kernel_stride_A, kernel_stride_B, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, actual_a, bp, cp);
}
}
}
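
A note on the stride_panelB padding kept by this diff: on typical hardware a 64-byte cache line holds 16 single-precision elements, so a panel stride that is a multiple of 32 elements keeps every panel column on the same cache-line alignment; padding by 16 staggers successive columns, which is what the "+~0.5 Gflops" comment refers to. A throwaway restatement of that rule (padded_stride is an illustrative name, not Eigen API):

    #include <cstdio>

    // Pad strides that are a multiple of 32 elements by 16 so that
    // successive panel columns do not share cache-line alignment.
    static int padded_stride(int kc) {
      return (kc % 32 == 0) ? kc + 16 : kc;
    }

    int main() {
      for (int kc : {24, 32, 48, 64, 96, 128}) {
        std::printf("kc = %3d -> stride_panelB = %3d\n", kc, padded_stride(kc));
      }
      return 0;
    }
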