diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 16:42:56 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-12-21 16:42:56 -0800 |
commit | 4236aebe103b0fa54f3b9e7e3c0c12094fa6e200 (patch) | |
tree | eb3c18dd42f192442834877eb34a4c65a8cc7b20 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | |
parent | 519d63d350222ddbed5db1883a8fb2c7aab4b4e9 (diff) |
Simplified the contraction code
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 45 |
1 files changed, 18 insertions, 27 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index c446ba1af..442c14fac 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -720,24 +720,20 @@ protected: const LhsScalar* leftData = m_leftImpl.data(); const RhsScalar* rightData = m_rightImpl.data(); - libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m); - libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k); - libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m); + const libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m); + const libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k); + const libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m); - libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc); + const libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc); // Use bigger stride to avoid hitting same cache line too often. // This consistently gives +~0.5 Gflops. - libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>( + const libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>( kc % 32 == 0 ? kc + 16 : kc ); // Kernel for the general case (not edges) internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel; - const LhsScalar *ap; - const RhsScalar *bp; - const Scalar *cp; - LhsScalar* blockA = NULL; RhsScalar* panelB = NULL; @@ -748,8 +744,8 @@ protected: panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar))); } - Index kernel_stride_A = copyA ? stride_blockA : stride_A; - Index kernel_stride_B = copyB ? stride_panelB : stride_B; + const Index kernel_stride_A = copyA ? stride_blockA : stride_A; + const Index kernel_stride_B = copyB ? 
stride_panelB : stride_B; kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch()); // Outer blocking @@ -763,6 +759,7 @@ protected: // Inner blocking for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) { const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki; + const float beta = ki == 0 ? 0 : 1; if (copyB) { if (transposeB) { @@ -775,8 +772,8 @@ protected: for (Index mi = mi_outer; mi < mini(mi_outer+mc_outer, m); mi += mc) { const Index actual_mc = mini(mi_outer+mc_outer, mini(mi+mc, m)) - mi; - const LhsScalar * a = transposeA ? leftData + mi*stride_A + ki : - leftData + ki*stride_A + mi; + const LhsScalar* a = transposeA ? leftData + mi*stride_A + ki : + leftData + ki*stride_A + mi; if (copyA) { if (transposeA) { @@ -785,30 +782,24 @@ protected: internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A); } } + const LhsScalar* actual_a = copyA ? blockA : a; for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) { const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni; - const RhsScalar * b = rightData + ni*stride_B + ki; - Scalar * c = buffer + ni*stride_C + mi; - cp = c + nc*stride_C; - - const LhsScalar * actual_a = copyA ? blockA : a; - const Index actual_lda = copyA ? stride_blockA : stride_A; - ap = copyA ? blockA : a; - - const RhsScalar * actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b; - const Index actual_ldb = copyB ? stride_panelB : stride_B; - bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B; + const RhsScalar* b = rightData + ni*stride_B + ki; + Scalar* c = buffer + ni*stride_C + mi; + const Scalar* cp = c + nc*stride_C; - float beta = ki == 0 ? 0 : 1; + const RhsScalar* actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b; + const RhsScalar* bp = copyB ? 
panelB + nc*stride_panelB : b + nc*stride_B; if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) { // Most used, cached kernel. - kernel(actual_a, actual_b, c, ap, bp, cp); + kernel(actual_a, actual_b, c, actual_a, bp, cp); } else { // Edges - use libxsmm kernel cache. - internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, actual_lda, actual_ldb, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, ap, bp, cp); + internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, kernel_stride_A, kernel_stride_B, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, actual_a, bp, cp); } } } |