path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h
author    Benoit Steiner <benoit.steiner.goog@gmail.com>    2014-10-03 19:18:07 -0700
committer Benoit Steiner <benoit.steiner.goog@gmail.com>    2014-10-03 19:18:07 -0700
commit  af2e5995e2ba48384024bbc8432bd6dbbebf71d2 (patch)
tree    ceec127775a3ff52f815836cb77efba63d31719b /unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h
parent  12693928228922ecf8fa3fcf14341d195e376a11 (diff)
Improved support for CUDA devices.
Improved contractions on GPU
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h  1206
1 file changed, 1206 insertions, 0 deletions
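For orientation before the patch body: the kernels added below tile the m x n output into 64 x 64 blocks, one per CUDA block of 8 x 8 x 8 threads, stream the contraction dimension through shared memory in 64-element slabs, and reduce each thread's 8 x 8 accumulator across threadIdx.x with warp shuffles. Functionally, each kernel computes the plain contraction sketched below. The reference loop is not part of the commit; it assumes only the mapper call syntax lhs(i, k), rhs(k, j), output(i, j) that the kernels use, and the helper name is hypothetical.

// Editorial reference: the result the CUDA kernels below produce, written as a
// naive loop nest over the same mapper interface and parameters the kernels take.
template <typename Scalar, typename Index, typename LhsMapper,
          typename RhsMapper, typename OutputMapper>
void contractionReference(const LhsMapper lhs, const RhsMapper rhs,
                          const OutputMapper output,
                          const Index m_size, const Index n_size, const Index k_size) {
  for (Index i = 0; i < m_size; ++i) {
    for (Index j = 0; j < n_size; ++j) {
      Scalar sum = Scalar(0);
      for (Index k = 0; k < k_size; ++k) {
        sum += lhs(i, k) * rhs(k, j);  // lhs is m x k, rhs is k x n, exactly as indexed in the kernels
      }
      output(i, j) = sum;              // the kernels write one 64 x 64 tile of this result per block
    }
  }
}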
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h
new file mode 100644
index 000000000..babe33fff
--- /dev/null
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h
@@ -0,0 +1,1206 @@
+// This file is part of Eigen, a lightweight C++ template library
+// for linear algebra.
+//
+// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla
+// Public License v. 2.0. If a copy of the MPL was not distributed
+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
+
+#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)
+
+namespace Eigen {
+
+template<typename Scalar, typename Index, typename LhsMapper,
+ typename RhsMapper, typename OutputMapper, bool needs_edge_check>
+__device__ EIGEN_STRONG_INLINE void
+EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
+ const OutputMapper output, volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem,
+ const Index m_size, const Index n_size, const Index k_size) {
+
+ const Index m_block_idx = blockIdx.x;
+ const Index n_block_idx = blockIdx.y;
+
+ const Index base_m = 64 * m_block_idx;
+ const Index base_n = 64 * n_block_idx;
+
+ // declare and initialize 64 registers for output 8x8 block
+
+ // prefetch registers
+ Scalar lhs_pf0;
+ Scalar lhs_pf1;
+ Scalar lhs_pf2;
+ Scalar lhs_pf3;
+ Scalar lhs_pf4;
+ Scalar lhs_pf5;
+ Scalar lhs_pf6;
+ Scalar lhs_pf7;
+
+ Scalar rhs_pf0;
+ Scalar rhs_pf1;
+ Scalar rhs_pf2;
+ Scalar rhs_pf3;
+ Scalar rhs_pf4;
+ Scalar rhs_pf5;
+ Scalar rhs_pf6;
+ Scalar rhs_pf7;
+
+ // shared memory is formatted
+ // (contract idx in block, nocontract idx in block, block idx)
+ // where block idx is column major. This transposition limits the number of
+ // bank conflicts when reading the LHS. The core idea is that since the contracting
+ // index is shared by both sides, it should be the index mapped to threadIdx.x.
+
+ // On the LHS, we pad each row inside of each block with an extra element. This makes
+ // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
+ // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
+
+ // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
+ // conflicts on writes and also none on reads.
+
+ // storage indices
+ const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
+ const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
+
+ const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
+ const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
+ const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
+ const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
+ const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
+ const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
+ const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
+ const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
+
+ const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
+ const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
+ const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
+ const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
+ const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
+ const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
+ const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
+ const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
+
+ // in the loading code, the following variables are important:
+ // threadIdx.x: the vertical position in an 8x8 block
+ // threadIdx.y: the vertical index of the 8x8 block in the grid
+ // threadIdx.z: the horizontal position in an 8x8 block
+ // k: the horizontal index of the 8x8 block in the grid
+ //
+ // The k parameter is implicit (it was the loop counter for a loop that went
+ // from 0 to <8, but that loop is now unrolled in the code below).
+
+ const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
+ const Index lhs_vert = base_m + load_idx_vert;
+
+#define prefetchIntoRegisters(base_k) \
+ { \
+ lhs_pf0 = Scalar(0); \
+ lhs_pf1 = Scalar(0); \
+ lhs_pf2 = Scalar(0); \
+ lhs_pf3 = Scalar(0); \
+ lhs_pf4 = Scalar(0); \
+ lhs_pf5 = Scalar(0); \
+ lhs_pf6 = Scalar(0); \
+ lhs_pf7 = Scalar(0); \
+ \
+ rhs_pf0 = Scalar(0); \
+ rhs_pf1 = Scalar(0); \
+ rhs_pf2 = Scalar(0); \
+ rhs_pf3 = Scalar(0); \
+ rhs_pf4 = Scalar(0); \
+ rhs_pf5 = Scalar(0); \
+ rhs_pf6 = Scalar(0); \
+ rhs_pf7 = Scalar(0); \
+ \
+ if (!needs_edge_check || lhs_vert < m_size) { \
+ const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
+ const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
+ const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
+ const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
+ const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
+ const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
+ const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
+ const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
+ \
+ if (!needs_edge_check || lhs_horiz_7 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
+ lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
+ lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
+ lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
+ lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
+ } else if (lhs_horiz_6 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
+ lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
+ lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
+ lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
+ } else if (lhs_horiz_5 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
+ lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
+ lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
+ } else if (lhs_horiz_4 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
+ lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
+ } else if (lhs_horiz_3 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
+ } else if (lhs_horiz_2 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
+ } else if (lhs_horiz_1 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
+ } else if (lhs_horiz_0 < k_size) { \
+ lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
+ } \
+ } \
+ \
+ const Index rhs_vert = base_k + load_idx_vert; \
+ if (!needs_edge_check || rhs_vert < k_size) { \
+ const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
+ const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
+ const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
+ const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
+ const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
+ const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
+ const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
+ const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
+ \
+ if (rhs_horiz_7 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
+ rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
+ rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
+ rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
+ rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
+ } else if (rhs_horiz_6 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
+ rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
+ rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
+ rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
+ } else if (rhs_horiz_5 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
+ rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
+ rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
+ } else if (rhs_horiz_4 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
+ rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
+ } else if (rhs_horiz_3 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
+ } else if (rhs_horiz_2 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
+ } else if (rhs_horiz_1 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
+ } else if (rhs_horiz_0 < n_size) { \
+ rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
+ } \
+ } \
+ } \
+
+#define writeRegToShmem(_) \
+ lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
+ rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
+ \
+ lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
+ rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
+ \
+ lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
+ rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
+ \
+ lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
+ rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
+ \
+ lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
+ rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
+ \
+ lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
+ rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
+ \
+ lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
+ rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
+ \
+ lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
+ rhs_shmem[rhs_store_idx_7] = rhs_pf7; \
+
+ // declare and initialize result array
+#define res(i, j) _res_##i##j
+#define initResultRow(i) \
+ Scalar res(i, 0) = Scalar(0); \
+ Scalar res(i, 1) = Scalar(0); \
+ Scalar res(i, 2) = Scalar(0); \
+ Scalar res(i, 3) = Scalar(0); \
+ Scalar res(i, 4) = Scalar(0); \
+ Scalar res(i, 5) = Scalar(0); \
+ Scalar res(i, 6) = Scalar(0); \
+ Scalar res(i, 7) = Scalar(0); \
+
+ initResultRow(0);
+ initResultRow(1);
+ initResultRow(2);
+ initResultRow(3);
+ initResultRow(4);
+ initResultRow(5);
+ initResultRow(6);
+ initResultRow(7);
+#undef initResultRow
+
+ for (Index base_k = 0; base_k < k_size; base_k += 64) {
+ // wait for previous iteration to finish with shmem. Despite common sense,
+ // the code is a bit faster with this here than at the bottom of the loop
+ __syncthreads();
+
+ prefetchIntoRegisters(base_k);
+ writeRegToShmem();
+
+ #undef prefetchIntoRegisters
+ #undef writeRegToShmem
+
+ // wait for shared mem packing to be done before starting computation
+ __syncthreads();
+
+ // compute 8x8 matrix product by outer product. This involves packing one column
+ // of LHS and one row of RHS into registers (takes 16 registers).
+
+#define lcol(i) _lcol##i
+ Scalar lcol(0);
+ Scalar lcol(1);
+ Scalar lcol(2);
+ Scalar lcol(3);
+ Scalar lcol(4);
+ Scalar lcol(5);
+ Scalar lcol(6);
+ Scalar lcol(7);
+
+#define rrow(j) _rrow##j
+ Scalar rrow(0);
+ Scalar rrow(1);
+ Scalar rrow(2);
+ Scalar rrow(3);
+ Scalar rrow(4);
+ Scalar rrow(5);
+ Scalar rrow(6);
+ Scalar rrow(7);
+
+ // Now x corresponds to k, y to m, and z to n
+ const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
+ const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
+
+#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
+#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
+
+#define loadData(i, j) \
+ lcol(0) = lhs_element(0, j); \
+ rrow(0) = rhs_element(i, 0); \
+ lcol(1) = lhs_element(1, j); \
+ rrow(1) = rhs_element(i, 1); \
+ lcol(2) = lhs_element(2, j); \
+ rrow(2) = rhs_element(i, 2); \
+ lcol(3) = lhs_element(3, j); \
+ rrow(3) = rhs_element(i, 3); \
+ lcol(4) = lhs_element(4, j); \
+ rrow(4) = rhs_element(i, 4); \
+ lcol(5) = lhs_element(5, j); \
+ rrow(5) = rhs_element(i, 5); \
+ lcol(6) = lhs_element(6, j); \
+ rrow(6) = rhs_element(i, 6); \
+ lcol(7) = lhs_element(7, j); \
+ rrow(7) = rhs_element(i, 7); \
+
+#define computeCol(j) \
+ res(0, j) += lcol(0) * rrow(j); \
+ res(1, j) += lcol(1) * rrow(j); \
+ res(2, j) += lcol(2) * rrow(j); \
+ res(3, j) += lcol(3) * rrow(j); \
+ res(4, j) += lcol(4) * rrow(j); \
+ res(5, j) += lcol(5) * rrow(j); \
+ res(6, j) += lcol(6) * rrow(j); \
+ res(7, j) += lcol(7) * rrow(j); \
+
+#define computePass(i) \
+ loadData(i, i); \
+ \
+ computeCol(0); \
+ computeCol(1); \
+ computeCol(2); \
+ computeCol(3); \
+ computeCol(4); \
+ computeCol(5); \
+ computeCol(6); \
+ computeCol(7); \
+
+ computePass(0);
+ computePass(1);
+ computePass(2);
+ computePass(3);
+ computePass(4);
+ computePass(5);
+ computePass(6);
+ computePass(7);
+
+#undef lcol
+#undef rrow
+#undef lhs_element
+#undef rhs_element
+#undef loadData
+#undef computeCol
+#undef computePass
+ } // end loop over k
+
+ // we've now iterated over all of the large (i.e. width 64) k blocks and
+ // accumulated results in registers. At this point thread (x, y, z) contains
+ // the sum across all big k blocks of the product of the little LHS block of index (x, y)
+ // with the little RHS block of index (x, z). To compute the final output, we need to
+ // sum over the 8 threads that differ only in their x index.
+#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
+
+#define reduceRow(i, mask) \
+ shuffleInc(i, 0, mask); \
+ shuffleInc(i, 1, mask); \
+ shuffleInc(i, 2, mask); \
+ shuffleInc(i, 3, mask); \
+ shuffleInc(i, 4, mask); \
+ shuffleInc(i, 5, mask); \
+ shuffleInc(i, 6, mask); \
+ shuffleInc(i, 7, mask); \
+
+#define reduceMatrix(mask) \
+ reduceRow(0, mask); \
+ reduceRow(1, mask); \
+ reduceRow(2, mask); \
+ reduceRow(3, mask); \
+ reduceRow(4, mask); \
+ reduceRow(5, mask); \
+ reduceRow(6, mask); \
+ reduceRow(7, mask); \
+
+ // actually perform the reduction; now each thread of index (_, y, z)
+ // contains the correct values in its registers that belong in the output
+ // block
+ reduceMatrix(1);
+ reduceMatrix(2);
+ reduceMatrix(4);
+
+#undef shuffleInc
+#undef reduceRow
+#undef reduceMatrix
+
+ // now we need to copy the 64 values into main memory. We can't split work
+ // among threads because all variables are in registers. There are two ways
+ // to do this:
+ // (1) have 1 thread do 64 writes from registers into global memory
+ // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
+ // each do 8 writes into global memory. We can just overwrite the shared
+ // memory from the problem we just solved.
+ // (2) is slightly faster than (1) due to less branching and more ILP
+
+ // TODO: won't yield much gain, but could just use currently unused shared mem
+ // and then we won't have to sync
+ // wait for shared mem to be out of use
+ __syncthreads();
+
+#define writeResultShmem(i, j) \
+ lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \
+
+#define writeRow(i) \
+ writeResultShmem(i, 0); \
+ writeResultShmem(i, 1); \
+ writeResultShmem(i, 2); \
+ writeResultShmem(i, 3); \
+ writeResultShmem(i, 4); \
+ writeResultShmem(i, 5); \
+ writeResultShmem(i, 6); \
+ writeResultShmem(i, 7); \
+
+ if (threadIdx.x == 0) {
+ writeRow(0);
+ writeRow(1);
+ writeRow(2);
+ writeRow(3);
+ writeRow(4);
+ writeRow(5);
+ writeRow(6);
+ writeRow(7);
+ }
+#undef writeResultShmem
+#undef writeRow
+
+ const int max_i_write = (min)((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
+ const int max_j_write = (min)((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
+
+ if (threadIdx.x < max_i_write) {
+ if (max_j_write == 8) {
+ Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
+ Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
+ Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
+ Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
+ Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
+ Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
+ Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
+ Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
+
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
+ } else {
+#pragma unroll 7
+ for (int j = 0; j < max_j_write; j++) {
+ Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
+ output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
+ }
+ }
+ }
+#undef res
+ }
+
+
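A quick editorial check, not part of the patch, of how the padding described at the top of EigenContractionKernelInternal produces the 72 * 64 shared-memory sizing declared in the wrapper kernel below; the constant names are hypothetical.

// LHS: an 8 x 8 grid of sub-blocks, each 8 rows padded to 9 elements.
// RHS: the same 64 sub-blocks of 64 elements, each followed by 8 padding floats.
enum : int {
  kLhsPaddedSubBlock = 8 * 9,                                   // 72 floats per LHS sub-block
  kSubBlocksPerTile  = 8 * 8,                                   // 64 sub-blocks per 64 x 64 tile
  kLhsShmemFloats    = kSubBlocksPerTile * kLhsPaddedSubBlock,  // 4608 == 72 * 64
  kRhsShmemFloats    = kSubBlocksPerTile * (64 + 8)             // 4608 == 72 * 64 as well
};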
+ template<typename Scalar, typename Index, typename LhsMapper,
+ typename RhsMapper, typename OutputMapper>
+__global__ void
+__launch_bounds__(512)
+ EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
+ const OutputMapper output,
+ const Index m_size, const Index n_size, const Index k_size) {
+ __shared__ volatile Scalar lhs_shmem[72 * 64];
+ __shared__ volatile Scalar rhs_shmem[72 * 64];
+
+ const Index m_block_idx = blockIdx.x;
+ const Index n_block_idx = blockIdx.y;
+
+ const Index base_m = 64 * m_block_idx;
+ const Index base_n = 64 * n_block_idx;
+
+ if (base_m + 63 < m_size && base_n + 63 < n_size) {
+ EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
+ } else {
+ EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
+ }
+ }
+
+
+
+ template<typename Index, typename LhsMapper,
+ typename RhsMapper, typename OutputMapper, bool needs_edge_check>
+__device__ EIGEN_STRONG_INLINE void
+ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
+ const OutputMapper output, float4* lhs_shmem4, float2* rhs_shmem2,
+ const Index m_size, const Index n_size, const Index k_size) {
+ typedef float Scalar;
+
+ const Index m_block_idx = blockIdx.x;
+ const Index n_block_idx = blockIdx.y;
+
+ const Index base_m = 64 * m_block_idx;
+ const Index base_n = 64 * n_block_idx;
+
+ const Index lane = threadIdx.x + 8 * (threadIdx.y % 4);
+
+ // prefetch registers
+ float4 lhs_pf0;
+ float4 lhs_pf1;
+
+ float4 rhs_pf0;
+ float4 rhs_pf1;
+
+ // shared memory is formatted
+ // (contract idx in block, nocontract idx in block, block idx)
+ // where block idx is column major. This transposition limits the number of
+ // bank conflicts when reading the LHS. The core idea is that since the contracting
+ // index is shared by both sides, it should be the index mapped to threadIdx.x.
+
+ // all of these indices assume float4 loading
+ // this thread loads the float4 starting at this index, and then also loads
+ // another float4 starting 32 columns to the right
+ const Index horiz_block_idx = threadIdx.z / 2;
+ const Index vert_block_idx = threadIdx.x / 2 + 4 * (threadIdx.y % 2);
+ const Index horiz_idx_in_block = threadIdx.y / 2 + 4 * (threadIdx.z % 2);
+ const Index vert_idx_in_block = threadIdx.x % 2;
+
+ // there's padding in both the LHS and RHS shared memory layouts. This padding
+ // allows for 0 bank conflicts on all shmem stores and loads.
+ // LHS padding: 1 float4 on each 8x8 block of floats
+ // RHS padding: 1 float2 on each block, and 12 additional float2s between vertical blocks
+ // 3 and 4
+
+ // storage indices
+ // lhs index with respect to float4s
+ const Index lhs_store_idx_base =
+ 136 * horiz_block_idx +
+ 17 * vert_block_idx +
+ 8 * vert_idx_in_block +
+ horiz_idx_in_block;
+
+ // rhs index with respect to floats
+ const Index rhs_store_idx_base =
+ 552 * horiz_block_idx +
+ 66 * vert_block_idx +
+ 32 * (horiz_idx_in_block / 4) + (horiz_idx_in_block % 4) +
+ 16 * vert_idx_in_block +
+ ((vert_block_idx < 4) ? 0 : 24);
+
+ const Index lhs_store_idx_0 = lhs_store_idx_base + 544 * 0;
+ const Index lhs_store_idx_1 = lhs_store_idx_base + 544 * 1;
+
+ const Index rhs_store_idx_0 = (rhs_store_idx_base / 2) + ((lane < 16) ? 0 : 4);
+ const Index rhs_store_idx_1 = rhs_store_idx_0 + 2;
+ const Index rhs_store_idx_2 = rhs_store_idx_0 + 1104;
+ const Index rhs_store_idx_3 = rhs_store_idx_1 + 1104;
+
+ // The below diagrams show which shmem index (with respect to floats) each element
+ // in an 8x8 input block gets packed into:
+ // LHS:
+ // 0 4 8 12 16 20 24 28
+ // 1 5 9 13 17 21 25 29
+ // 2 6 10 14 18 22 26 30
+ // 3 7 11 15 19 23 27 31
+ // 32 36 40 44 48 52 56 60
+ // ... (pack as 2 rows of float4 indexed row major, each float4 is vertical)
+ //
+ // RHS:
+ // 0 1 2 3 32 33 34 35
+ // 4 5 6 7 36 37 38 39
+ // ... (pack as 2 cols of float4 indexed col major, each float4 is horizontal)
+
+ // Each thread in a warp loads 2 float4s. This happens in 2 instructions. On each of these
+ // instructions, the warp loads 2 columns (2 cols * 64 elements / col = 128 elements = 32 threads
+ // * 4 elements/thread). For the LHS, we're able to store the loaded float4 directly into
+ // shmem (using a 128 bit store instruction). For the RHS, we need to transpose the data.
+ // This is done with warp shuffles. Furthermore, we only use 64 bit stores for the RHS, because
+ // 64 bits is only 2 columns (which is all we load in a warp), and the padding for the RHS
+ // doesn't meet the 128 bit alignment requirement of float4 (namely, the 4 consecutive floats that we want
+ // to load on the RHS are 8 byte aligned, not 16 byte aligned, which is required for float4).
+
+ const Index load_idx_vert = 4 * (threadIdx.x + 8 * (threadIdx.y % 2));
+ const Index load_idx_horiz = (threadIdx.y / 2) + 4 * threadIdx.z;
+
+ const Index lhs_vert = base_m + load_idx_vert;
+ const Index rhs_horiz_0 = base_n + load_idx_horiz;
+ const Index rhs_horiz_1 = base_n + load_idx_horiz + 32;
+
+#define prefetchIntoRegisters(base_k) \
+ { \
+ lhs_pf0 = internal::pset1<float4>(0); \
+ lhs_pf1 = internal::pset1<float4>(0); \
+ \
+ rhs_pf0 = internal::pset1<float4>(0); \
+ rhs_pf1 = internal::pset1<float4>(0); \
+ \
+ const Index lhs_horiz_0 = base_k + load_idx_horiz; \
+ const Index lhs_horiz_1 = base_k + load_idx_horiz + 32; \
+ if (!needs_edge_check || lhs_vert + 3 < m_size) { \
+ if (lhs_horiz_1 < k_size) { \
+ lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \
+ lhs_pf1 = lhs.loadPacket(lhs_vert, lhs_horiz_1); \
+ } else if (lhs_horiz_0 < k_size) { \
+ lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \
+ } \
+ } else if (lhs_vert + 2 < m_size) { \
+ if (lhs_horiz_1 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \
+ lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \
+ \
+ lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \
+ lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \
+ lhs_pf1.z = lhs(lhs_vert + 2, lhs_horiz_1); \
+ } else if (lhs_horiz_0 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \
+ lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \
+ } \
+ } else if (lhs_vert + 1 < m_size) { \
+ if (lhs_horiz_1 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \
+ \
+ lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \
+ lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \
+ } else if (lhs_horiz_0 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \
+ } \
+ } else if (lhs_vert < m_size) { \
+ if (lhs_horiz_1 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \
+ } else if (lhs_horiz_0 < k_size) { \
+ lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \
+ } \
+ } \
+ \
+ const Index rhs_vert = base_k + load_idx_vert; \
+ if (rhs_vert + 3 < k_size) { \
+ if (!needs_edge_check || rhs_horiz_1 < n_size) { \
+ rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \
+ rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz_1); \
+ } else if (rhs_horiz_0 < n_size) { \
+ rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \
+ } \
+ } else if (rhs_vert + 2 < k_size) { \
+ if (!needs_edge_check || rhs_horiz_1 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \
+ rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \
+ \
+ rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \
+ rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \
+ rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz_1); \
+ } else if (rhs_horiz_0 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \
+ rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \
+ } \
+ } else if (rhs_vert + 1 < k_size) { \
+ if (!needs_edge_check || rhs_horiz_1 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \
+ \
+ rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \
+ rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \
+ } else if (rhs_horiz_0 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \
+ } \
+ } else if (rhs_vert < k_size) { \
+ if (!needs_edge_check || rhs_horiz_1 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \
+ } else if (rhs_horiz_0 < n_size) { \
+ rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \
+ } \
+ } \
+ \
+ float swap_val0 = (lane < 16) ? rhs_pf0.z : rhs_pf0.x; \
+ float swap_val1 = (lane < 16) ? rhs_pf0.w : rhs_pf0.y; \
+ float swap_val2 = (lane < 16) ? rhs_pf1.z : rhs_pf1.x; \
+ float swap_val3 = (lane < 16) ? rhs_pf1.w : rhs_pf1.y; \
+ \
+ swap_val0 = __shfl_xor(swap_val0, 16); \
+ swap_val1 = __shfl_xor(swap_val1, 16); \
+ swap_val2 = __shfl_xor(swap_val2, 16); \
+ swap_val3 = __shfl_xor(swap_val3, 16); \
+ \
+ if (lane < 16) { \
+ rhs_pf0.z = swap_val0; \
+ rhs_pf0.w = swap_val1; \
+ rhs_pf1.z = swap_val2; \
+ rhs_pf1.w = swap_val3; \
+ } else { \
+ rhs_pf0.x = swap_val0; \
+ rhs_pf0.y = swap_val1; \
+ rhs_pf1.x = swap_val2; \
+ rhs_pf1.y = swap_val3; \
+ } \
+ } \
+
+
+#define writeRegToShmem(_) \
+ lhs_shmem4[lhs_store_idx_0] = lhs_pf0; \
+ \
+ rhs_shmem2[rhs_store_idx_0] = make_float2(rhs_pf0.x, rhs_pf0.z); \
+ rhs_shmem2[rhs_store_idx_1] = make_float2(rhs_pf0.y, rhs_pf0.w); \
+ \
+ lhs_shmem4[lhs_store_idx_1] = lhs_pf1; \
+ \
+ rhs_shmem2[rhs_store_idx_2] = make_float2(rhs_pf1.x, rhs_pf1.z); \
+ rhs_shmem2[rhs_store_idx_3] = make_float2(rhs_pf1.y, rhs_pf1.w); \
+
+ // declare and initialize result array
+#define res(i, j) _res_##i##j
+#define initResultRow(i) \
+ Scalar res(i, 0) = Scalar(0); \
+ Scalar res(i, 1) = Scalar(0); \
+ Scalar res(i, 2) = Scalar(0); \
+ Scalar res(i, 3) = Scalar(0); \
+ Scalar res(i, 4) = Scalar(0); \
+ Scalar res(i, 5) = Scalar(0); \
+ Scalar res(i, 6) = Scalar(0); \
+ Scalar res(i, 7) = Scalar(0); \
+
+ initResultRow(0);
+ initResultRow(1);
+ initResultRow(2);
+ initResultRow(3);
+ initResultRow(4);
+ initResultRow(5);
+ initResultRow(6);
+ initResultRow(7);
+#undef initResultRow
+
+ for (Index base_k = 0; base_k < k_size; base_k += 64) {
+ // wait for previous iteration to finish with shmem. Despite common sense,
+ // the code is a bit faster with this here than at the bottom of the loop
+ __syncthreads();
+
+ prefetchIntoRegisters(base_k);
+ writeRegToShmem();
+
+#undef prefetchIntoRegisters
+#undef writeRegToShmem
+
+ // wait for shared mem packing to be done before starting computation
+ __syncthreads();
+
+ // compute 8x8 matrix product by outer product. This involves packing one column
+ // of LHS and one row of RHS into registers (takes 16 registers).
+
+ float4 _lcol0;
+ float4 _lcol1;
+ float2 _rrow0;
+ float2 _rrow1;
+ float2 _rrow2;
+ float2 _rrow3;
+
+#define lcol0 _lcol0.x
+#define lcol1 _lcol0.y
+#define lcol2 _lcol0.z
+#define lcol3 _lcol0.w
+#define lcol4 _lcol1.x
+#define lcol5 _lcol1.y
+#define lcol6 _lcol1.z
+#define lcol7 _lcol1.w
+#define rrow0 _rrow0.x
+#define rrow1 _rrow0.y
+#define rrow2 _rrow1.x
+#define rrow3 _rrow1.y
+#define rrow4 _rrow2.x
+#define rrow5 _rrow2.y
+#define rrow6 _rrow3.x
+#define rrow7 _rrow3.y
+
+ // Now x corresponds to k, y to m, and z to n
+ const float4* lhs_block = &lhs_shmem4[threadIdx.x + 8 * (threadIdx.y % 2) + 17 * (threadIdx.y / 2)];
+ const float2* rhs_block = &rhs_shmem2[2 * threadIdx.x + 16 * (threadIdx.z % 2) + 276 * (threadIdx.z / 2)];
+
+#define lhs_element(i, k) lhs_block[68 * i + 136 * k]
+#define rhs_element(k, j) rhs_block[33 * k + 1104 * j + ((k < 4) ? 0 : 12)]
+
+#define loadData(i) \
+ _lcol0 = lhs_element(0, i); \
+ _rrow0 = rhs_element(i, 0); \
+ _rrow1 = *(&(rhs_element(i, 0)) + 1); \
+ _lcol1 = lhs_element(1, i); \
+ _rrow2 = rhs_element(i, 1); \
+ _rrow3 = *(&(rhs_element(i, 1)) + 1); \
+
+#define computeCol(j) \
+ res(0, j) += lcol0 * rrow##j; \
+ res(1, j) += lcol1 * rrow##j; \
+ res(2, j) += lcol2 * rrow##j; \
+ res(3, j) += lcol3 * rrow##j; \
+ res(4, j) += lcol4 * rrow##j; \
+ res(5, j) += lcol5 * rrow##j; \
+ res(6, j) += lcol6 * rrow##j; \
+ res(7, j) += lcol7 * rrow##j; \
+
+#define computePass(i) \
+ loadData(i); \
+ \
+ computeCol(0); \
+ computeCol(1); \
+ computeCol(2); \
+ computeCol(3); \
+ computeCol(4); \
+ computeCol(5); \
+ computeCol(6); \
+ computeCol(7); \
+
+ computePass(0);
+ computePass(1);
+ computePass(2);
+ computePass(3);
+ computePass(4);
+ computePass(5);
+ computePass(6);
+ computePass(7);
+
+#undef lcol0
+#undef lcol1
+#undef lcol2
+#undef lcol3
+#undef lcol4
+#undef lcol5
+#undef lcol6
+#undef lcol7
+#undef rrow0
+#undef rrow1
+#undef rrow2
+#undef rrow3
+#undef rrow4
+#undef rrow5
+#undef rrow6
+#undef rrow7
+
+#undef computePass
+#undef computeCol
+#undef loadData
+#undef lhs_element
+#undef rhs_element
+
+ } // end loop over k
+
+ // we've now iterated over all of the large (i.e. width 64) k blocks and
+ // accumulated results in registers. At this point thread (x, y, z) contains
+ // the sum across all big k blocks of the product of the little LHS block of index (x, y)
+ // with the little RHS block of index (x, z). To compute the final output, we need to
+ // sum over the 8 threads that differ only in their x index.
+#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
+
+#define reduceRow(i, mask) \
+ shuffleInc(i, 0, mask); \
+ shuffleInc(i, 1, mask); \
+ shuffleInc(i, 2, mask); \
+ shuffleInc(i, 3, mask); \
+ shuffleInc(i, 4, mask); \
+ shuffleInc(i, 5, mask); \
+ shuffleInc(i, 6, mask); \
+ shuffleInc(i, 7, mask); \
+
+#define reduceMatrix(mask) \
+ reduceRow(0, mask); \
+ reduceRow(1, mask); \
+ reduceRow(2, mask); \
+ reduceRow(3, mask); \
+ reduceRow(4, mask); \
+ reduceRow(5, mask); \
+ reduceRow(6, mask); \
+ reduceRow(7, mask); \
+
+ // actually perform the reduction; now each thread of index (_, y, z)
+ // contains the correct values in its registers that belong in the output
+ // block
+ reduceMatrix(1);
+ reduceMatrix(2);
+ reduceMatrix(4);
+
+#undef shuffleInc
+#undef reduceRow
+#undef reduceMatrix
+
+ // now we need to copy the 64 values into main memory. We can't split work
+ // among threads because all variables are in registers. There are three ways
+ // to do this:
+ // (1) have 1 thread do 64 writes from registers into global memory
+ // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
+ // each do 8 writes into global memory. We can just overwrite the shared
+ // memory from the problem we just solved.
+ // (3) have each thread copy its values into new registers using conditional logic (the approach taken below).
+
+#define makeAssignments(i) \
+ val0 = res(i, 0); \
+ val1 = res(i, 1); \
+ val2 = res(i, 2); \
+ val3 = res(i, 3); \
+ val4 = res(i, 4); \
+ val5 = res(i, 5); \
+ val6 = res(i, 6); \
+ val7 = res(i, 7); \
+
+ Scalar val0;
+ Scalar val1;
+ Scalar val2;
+ Scalar val3;
+ Scalar val4;
+ Scalar val5;
+ Scalar val6;
+ Scalar val7;
+
+ switch (threadIdx.x) {
+ case 0:
+ makeAssignments(0);
+ break;
+ case 1:
+ makeAssignments(1);
+ break;
+ case 2:
+ makeAssignments(2);
+ break;
+ case 3:
+ makeAssignments(3);
+ break;
+ case 4:
+ makeAssignments(4);
+ break;
+ case 5:
+ makeAssignments(5);
+ break;
+ case 6:
+ makeAssignments(6);
+ break;
+ case 7:
+ makeAssignments(7);
+ break;
+ }
+
+#undef res
+
+ const Index vert_base = base_m + 4 * threadIdx.y + (threadIdx.x % 4) + 32 * (threadIdx.x / 4);
+ const Index horiz_base = base_n + 4 * threadIdx.z;
+
+ if (!needs_edge_check || vert_base < m_size) {
+ if (!needs_edge_check || horiz_base + 35 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ output(vert_base, horiz_base + 3) = val3;
+ output(vert_base, horiz_base + 32) = val4;
+ output(vert_base, horiz_base + 33) = val5;
+ output(vert_base, horiz_base + 34) = val6;
+ output(vert_base, horiz_base + 35) = val7;
+ } else if (horiz_base + 34 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ output(vert_base, horiz_base + 3) = val3;
+ output(vert_base, horiz_base + 32) = val4;
+ output(vert_base, horiz_base + 33) = val5;
+ output(vert_base, horiz_base + 34) = val6;
+ } else if (horiz_base + 33 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ output(vert_base, horiz_base + 3) = val3;
+ output(vert_base, horiz_base + 32) = val4;
+ output(vert_base, horiz_base + 33) = val5;
+ } else if (horiz_base + 32 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ output(vert_base, horiz_base + 3) = val3;
+ output(vert_base, horiz_base + 32) = val4;
+ } else if (horiz_base + 3 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ output(vert_base, horiz_base + 3) = val3;
+ } else if (horiz_base + 2 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ output(vert_base, horiz_base + 2) = val2;
+ } else if (horiz_base + 1 < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ output(vert_base, horiz_base + 1) = val1;
+ } else if (horiz_base < n_size) {
+ output(vert_base, horiz_base + 0) = val0;
+ }
+ }
+ }
+
+
+ template<typename Index, typename LhsMapper,
+ typename RhsMapper, typename OutputMapper>
+__global__ void
+ __launch_bounds__(512)
+ EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
+ const OutputMapper output,
+ const Index m_size, const Index n_size, const Index k_size) {
+ __shared__ float4 lhs_shmem[(68 * 64) / 4];
+ __shared__ float2 rhs_shmem[((66 * 8 + 24) * 8) / 2];
+
+ const Index m_block_idx = blockIdx.x;
+ const Index n_block_idx = blockIdx.y;
+
+ const Index base_m = 64 * m_block_idx;
+ const Index base_n = 64 * n_block_idx;
+
+ if (base_m + 63 < m_size && base_n + 63 < n_size) {
+ EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
+ } else {
+ EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
+ }
+ }
+
+
+ template<typename Indices, typename LeftArgType, typename RightArgType>
+ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
+ public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {
+
+ typedef GpuDevice Device;
+
+ typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+ typedef TensorContractionEvaluatorBase<Self> Base;
+
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+ typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
+ typedef typename XprType::Packet Packet;
+ typedef typename XprType::Index Index;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count> left_dim_mapper_t;
+ typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count> right_dim_mapper_t;
+
+ typedef array<Index, internal::array_size<Indices>::value> contract_t;
+ typedef array<Index, TensorEvaluator<LeftArgType, Device>::Dimensions::count - internal::array_size<Indices>::value> left_nocontract_t;
+ typedef array<Index, TensorEvaluator<RightArgType, Device>::Dimensions::count - internal::array_size<Indices>::value> right_nocontract_t;
+
+ static const int NumDims = max_n_1<TensorEvaluator<LeftArgType, Device>::Dimensions::count + TensorEvaluator<RightArgType, Device>::Dimensions::count - 2 * internal::array_size<Indices>::value>::size;
+
+ typedef DSizes<Index, NumDims> Dimensions;
+
+ // typedefs needed in evalTo
+ typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
+ typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
+
+ typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
+ typedef TensorEvaluator<RightArgType, Device> RightEvaluator;
+
+ typedef typename LeftEvaluator::Dimensions LeftDimensions;
+ typedef typename RightEvaluator::Dimensions RightDimensions;
+
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
+ Base(op, device) {}
+
+ // We need to redefine this method to make nvcc happy
+ EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
+ this->m_leftImpl.evalSubExprsIfNeeded(NULL);
+ this->m_rightImpl.evalSubExprsIfNeeded(NULL);
+ if (data) {
+ evalTo(data);
+ return false;
+ } else {
+ this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
+ evalTo(this->m_result);
+ return true;
+ }
+ }
+
+ void evalTo(Scalar* buffer) const {
+ if (this->m_lhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<true, true, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<true, true, false, Unaligned>(buffer);
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<true, false, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<true, false, false, Unaligned>(buffer);
+ }
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_contiguous) {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<false, true, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<false, true, false, Unaligned>(buffer);
+ }
+ }
+ else {
+ if (this->m_rhs_inner_dim_reordered) {
+ evalTyped<false, false, true, Unaligned>(buffer);
+ }
+ else {
+ evalTyped<false, false, false, Unaligned>(buffer);
+ }
+ }
+ }
+ }
+
+ template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+ void evalTyped(Scalar* buffer) const {
+ // columns in left side, rows in right side
+ const Index k = this->m_k_size;
+
+ // rows in left side
+ const Index m = this->m_i_size;
+
+ // columns in right side
+ const Index n = this->m_j_size;
+
+ // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
+ this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+
+ typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
+ LeftEvaluator, left_nocontract_t,
+ contract_t, 4,
+ lhs_inner_dim_contiguous,
+ false, Unaligned> LhsMapper;
+
+ typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
+ RightEvaluator, right_nocontract_t,
+ contract_t, 4,
+ rhs_inner_dim_contiguous,
+ rhs_inner_dim_reordered, Unaligned> RhsMapper;
+
+ typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+
+
+ // initialize data mappers
+ LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
+ this->m_left_contracting_strides, this->m_k_strides);
+
+ RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
+ this->m_right_contracting_strides, this->m_k_strides);
+
+ OutputMapper output(buffer, m);
+
+ const Index m_blocks = (m + 63) / 64;
+ const Index n_blocks = (n + 63) / 64;
+ const dim3 num_blocks(m_blocks, n_blocks, 1);
+ const dim3 block_size(8, 8, 8);
+
+ cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
+ if (internal::is_same<LhsScalar, float>::value &&
+ internal::is_same<RhsScalar, float>::value) {
+ EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>
+ <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k);
+ } else {
+ EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
+ <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k);
+ }
+
+ assert(cudaGetLastError() == cudaSuccess);
+ }
+ };
+
+} // end namespace Eigen
+
+#endif // EIGEN_USE_GPU and __CUDACC__
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
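For context on how the kernels in this file get exercised: evaluating a contraction expression on a GpuDevice goes through the TensorEvaluator defined above, which launches EigenFloatContractionKernel for float operands and EigenContractionKernel for other scalar types. The sketch below is not part of the commit; it follows the public Tensor API, and details such as the GpuDevice constructor and the dimension-pair type have varied across Eigen versions, so treat those spellings (and the helper name) as assumptions.

// Hypothetical usage sketch: contract an m x k matrix with a k x n matrix on the GPU.
// d_a, d_b and d_c are device buffers holding m*k, k*n and m*n floats respectively.
#include <unsupported/Eigen/CXX11/Tensor>

void contractOnGpu(float* d_a, float* d_b, float* d_c,
                   int m, int k, int n, const cudaStream_t* stream) {
  // Map the raw device memory as column-major rank-2 tensors.
  Eigen::TensorMap<Eigen::Tensor<float, 2> > A(d_a, m, k);
  Eigen::TensorMap<Eigen::Tensor<float, 2> > B(d_b, k, n);
  Eigen::TensorMap<Eigen::Tensor<float, 2> > C(d_c, m, n);

  // Contract dimension 1 of A against dimension 0 of B, i.e. an m x n matrix product.
  Eigen::array<Eigen::IndexPair<int>, 1> dims;
  dims[0] = Eigen::IndexPair<int>(1, 0);

  // How GpuDevice is constructed differs between Eigen versions; a raw stream
  // pointer is assumed here to match the device wrapper of this era.
  Eigen::GpuDevice gpu_device(stream);

  // Assigning through the device dispatches the float contraction kernel above
  // with a grid of ceil(m/64) x ceil(n/64) blocks of 8 x 8 x 8 threads.
  C.device(gpu_device) = A.contract(B, dims);
}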