From af2e5995e2ba48384024bbc8432bd6dbbebf71d2 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 3 Oct 2014 19:18:07 -0700 Subject: Improved support for CUDA devices. Improved contractions on GPU --- .../Eigen/CXX11/src/Tensor/TensorContractionCuda.h | 1206 ++++++++++++++++++++ 1 file changed, 1206 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h new file mode 100644 index 000000000..babe33fff --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -0,0 +1,1206 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H + +#if defined(EIGEN_USE_GPU) && defined(__CUDACC__) + +namespace Eigen { + +template +__device__ EIGEN_STRONG_INLINE void +EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem, + const Index m_size, const Index n_size, const Index k_size) { + + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + // declare and initialize 64 registers for output 8x8 block + + // prefetch registers + Scalar lhs_pf0; + Scalar lhs_pf1; + Scalar lhs_pf2; + Scalar lhs_pf3; + Scalar lhs_pf4; + Scalar lhs_pf5; + Scalar lhs_pf6; + Scalar lhs_pf7; + + Scalar rhs_pf0; + Scalar rhs_pf1; + Scalar rhs_pf2; + Scalar rhs_pf3; + Scalar rhs_pf4; + Scalar rhs_pf5; + Scalar rhs_pf6; + Scalar rhs_pf7; + + // shared memory is formatted + // (contract idx in block, nocontract idx in block, block idx) + // where block idx is column major. This transposition limits the number of + // bank conflicts when reading the LHS. The core idea is that since the contracting + // index is shared by both sides, then the contracting index should be in threadIdx.x. + + // On the LHS, we pad each row inside of each block with an extra element. This makes + // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts + // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks. + + // On the RHS we just add 8 padding elements to the end of each block. This gives no bank + // conflicts on writes and also none on reads. 
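+
+  // Illustrative sketch (kept under #if 0, never compiled): why a 9-element
+  // row pitch helps. Under the common 32-bank, 4-byte-word shared memory
+  // model, walking down a column with pitch 8 revisits the same few banks,
+  // while pitch 9 spreads the 8 accesses over 8 distinct banks. The bank
+  // width here is an assumption; the launcher below actually selects the
+  // eight-byte bank mode, which changes the exact numbers but not the idea.
+#if 0
+  #include <cstdio>
+
+  int main() {
+    const int kBanks = 32;                  // 4-byte banks (assumed)
+    const int pitches[2] = { 8, 9 };        // unpadded vs padded row length
+    for (int p = 0; p < 2; ++p) {
+      std::printf("pitch %d banks:", pitches[p]);
+      for (int row = 0; row < 8; ++row) {
+        // shared-memory word touched by element (row, 0) of one 8x8 block
+        std::printf(" %d", (row * pitches[p]) % kBanks);
+      }
+      // pitch 8 prints 0 8 16 24 0 8 16 24; pitch 9 prints 8 distinct banks
+      std::printf("\n");
+    }
+    return 0;
+  }
+#endif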
+ + // storage indices + const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z; + const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x; + + const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0; + const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1; + const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2; + const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3; + const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4; + const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5; + const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6; + const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7; + + const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0; + const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1; + const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2; + const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3; + const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4; + const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5; + const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6; + const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7; + + // in the loading code, the following variables are important: + // threadIdx.x: the vertical position in an 8x8 block + // threadIdx.y: the vertical index of the 8x8 block in the grid + // threadIdx.z: the horizontal position in an 8x8 block + // k: the horizontal index of the 8x8 block in the grid + // + // The k parameter is implicit (it was the loop counter for a loop that went + // from 0 to <8, but now that loop is unrolled in the below code. + + const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y; + const Index lhs_vert = base_m + load_idx_vert; + +#define prefetchIntoRegisters(base_k) \ + { \ + lhs_pf0 = Scalar(0); \ + lhs_pf1 = Scalar(0); \ + lhs_pf2 = Scalar(0); \ + lhs_pf3 = Scalar(0); \ + lhs_pf4 = Scalar(0); \ + lhs_pf5 = Scalar(0); \ + lhs_pf6 = Scalar(0); \ + lhs_pf7 = Scalar(0); \ + \ + rhs_pf0 = Scalar(0); \ + rhs_pf1 = Scalar(0); \ + rhs_pf2 = Scalar(0); \ + rhs_pf3 = Scalar(0); \ + rhs_pf4 = Scalar(0); \ + rhs_pf5 = Scalar(0); \ + rhs_pf6 = Scalar(0); \ + rhs_pf7 = Scalar(0); \ + \ + if (!needs_edge_check || lhs_vert < m_size) { \ + const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \ + const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \ + const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \ + const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \ + const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \ + const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \ + const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \ + const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \ + \ + if (!needs_edge_check || lhs_horiz_7 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ + lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ + lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ + lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \ + lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \ + } else if (lhs_horiz_6 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ + lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ + lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ + lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \ + } else if (lhs_horiz_5 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + 
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ + lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ + lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \ + } else if (lhs_horiz_4 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ + lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \ + } else if (lhs_horiz_3 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \ + } else if (lhs_horiz_2 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \ + } else if (lhs_horiz_1 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \ + } else if (lhs_horiz_0 < k_size) { \ + lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \ + } \ + } \ + \ + const Index rhs_vert = base_k + load_idx_vert; \ + if (!needs_edge_check || rhs_vert < k_size) { \ + const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \ + const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \ + const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \ + const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \ + const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \ + const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \ + const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \ + const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \ + \ + if (rhs_horiz_7 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ + rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ + rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ + rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \ + rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \ + } else if (rhs_horiz_6 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ + rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ + rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ + rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \ + } else if (rhs_horiz_5 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ + rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ + rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \ + } else if (rhs_horiz_4 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ + rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \ + } else if (rhs_horiz_3 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \ + } else if (rhs_horiz_2 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \ + } else if (rhs_horiz_1 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \ + } else if (rhs_horiz_0 < n_size) { \ + rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \ + } \ + } \ + } \ + +#define writeRegToShmem(_) \ + lhs_shmem[lhs_store_idx_0] = lhs_pf0; \ + 
rhs_shmem[rhs_store_idx_0] = rhs_pf0; \ + \ + lhs_shmem[lhs_store_idx_1] = lhs_pf1; \ + rhs_shmem[rhs_store_idx_1] = rhs_pf1; \ + \ + lhs_shmem[lhs_store_idx_2] = lhs_pf2; \ + rhs_shmem[rhs_store_idx_2] = rhs_pf2; \ + \ + lhs_shmem[lhs_store_idx_3] = lhs_pf3; \ + rhs_shmem[rhs_store_idx_3] = rhs_pf3; \ + \ + lhs_shmem[lhs_store_idx_4] = lhs_pf4; \ + rhs_shmem[rhs_store_idx_4] = rhs_pf4; \ + \ + lhs_shmem[lhs_store_idx_5] = lhs_pf5; \ + rhs_shmem[rhs_store_idx_5] = rhs_pf5; \ + \ + lhs_shmem[lhs_store_idx_6] = lhs_pf6; \ + rhs_shmem[rhs_store_idx_6] = rhs_pf6; \ + \ + lhs_shmem[lhs_store_idx_7] = lhs_pf7; \ + rhs_shmem[rhs_store_idx_7] = rhs_pf7; \ + + // declare and initialize result array +#define res(i, j) _res_##i##j +#define initResultRow(i) \ + Scalar res(i, 0) = Scalar(0); \ + Scalar res(i, 1) = Scalar(0); \ + Scalar res(i, 2) = Scalar(0); \ + Scalar res(i, 3) = Scalar(0); \ + Scalar res(i, 4) = Scalar(0); \ + Scalar res(i, 5) = Scalar(0); \ + Scalar res(i, 6) = Scalar(0); \ + Scalar res(i, 7) = Scalar(0); \ + + initResultRow(0); + initResultRow(1); + initResultRow(2); + initResultRow(3); + initResultRow(4); + initResultRow(5); + initResultRow(6); + initResultRow(7); +#undef initResultRow + + for (Index base_k = 0; base_k < k_size; base_k += 64) { + // wait for previous iteration to finish with shmem. Despite common sense, + // the code is a bit faster with this here then at bottom of loop + __syncthreads(); + + prefetchIntoRegisters(base_k); + writeRegToShmem(); + + #undef prefetchIntoRegisters + #undef writeRegToShmem + + // wait for shared mem packing to be done before starting computation + __syncthreads(); + + // compute 8x8 matrix product by outer product. This involves packing one column + // of LHS and one row of RHS into registers (takes 16 registers). 
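+
+    // A plain-loop sketch of what the macros below unroll (kept under #if 0,
+    // never compiled): every pass pulls one LHS column and one RHS row into
+    // registers and performs a rank-1 update of the 8x8 accumulator. The
+    // lhs_tile/rhs_tile arrays stand in for the values the real code reads
+    // from lhs_shmem/rhs_shmem.
+#if 0
+    float lhs_tile[8][8] = {};                 // 8x8 sub-block of the LHS
+    float rhs_tile[8][8] = {};                 // 8x8 sub-block of the RHS
+    float acc[8][8] = {};                      // plays the role of res(i, j)
+    for (int pass = 0; pass < 8; ++pass) {     // computePass(0) .. computePass(7)
+      float lcol[8], rrow[8];
+      for (int i = 0; i < 8; ++i) lcol[i] = lhs_tile[i][pass];  // one LHS column
+      for (int j = 0; j < 8; ++j) rrow[j] = rhs_tile[pass][j];  // one RHS row
+      for (int i = 0; i < 8; ++i)
+        for (int j = 0; j < 8; ++j)
+          acc[i][j] += lcol[i] * rrow[j];      // outer-product accumulation
+    }
+#endif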
+ +#define lcol(i) _lcol##i + Scalar lcol(0); + Scalar lcol(1); + Scalar lcol(2); + Scalar lcol(3); + Scalar lcol(4); + Scalar lcol(5); + Scalar lcol(6); + Scalar lcol(7); + +#define rrow(j) _rrow##j + Scalar rrow(0); + Scalar rrow(1); + Scalar rrow(2); + Scalar rrow(3); + Scalar rrow(4); + Scalar rrow(5); + Scalar rrow(6); + Scalar rrow(7); + + // Now x corresponds to k, y to m, and z to n + const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y]; + const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z]; + +#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))] +#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))] + +#define loadData(i, j) \ + lcol(0) = lhs_element(0, j); \ + rrow(0) = rhs_element(i, 0); \ + lcol(1) = lhs_element(1, j); \ + rrow(1) = rhs_element(i, 1); \ + lcol(2) = lhs_element(2, j); \ + rrow(2) = rhs_element(i, 2); \ + lcol(3) = lhs_element(3, j); \ + rrow(3) = rhs_element(i, 3); \ + lcol(4) = lhs_element(4, j); \ + rrow(4) = rhs_element(i, 4); \ + lcol(5) = lhs_element(5, j); \ + rrow(5) = rhs_element(i, 5); \ + lcol(6) = lhs_element(6, j); \ + rrow(6) = rhs_element(i, 6); \ + lcol(7) = lhs_element(7, j); \ + rrow(7) = rhs_element(i, 7); \ + +#define computeCol(j) \ + res(0, j) += lcol(0) * rrow(j); \ + res(1, j) += lcol(1) * rrow(j); \ + res(2, j) += lcol(2) * rrow(j); \ + res(3, j) += lcol(3) * rrow(j); \ + res(4, j) += lcol(4) * rrow(j); \ + res(5, j) += lcol(5) * rrow(j); \ + res(6, j) += lcol(6) * rrow(j); \ + res(7, j) += lcol(7) * rrow(j); \ + +#define computePass(i) \ + loadData(i, i); \ + \ + computeCol(0); \ + computeCol(1); \ + computeCol(2); \ + computeCol(3); \ + computeCol(4); \ + computeCol(5); \ + computeCol(6); \ + computeCol(7); \ + + computePass(0); + computePass(1); + computePass(2); + computePass(3); + computePass(4); + computePass(5); + computePass(6); + computePass(7); + +#undef lcol +#undef rrow +#undef lhs_element +#undef rhs_element +#undef loadData +#undef computeCol +#undef computePass + } // end loop over k + + // we've now iterated over all of the large (ie width 64) k blocks and + // accumulated results in registers. At this point thread (x, y, z) contains + // the sum across all big k blocks of the product of little k block of index (x, y) + // with block of index (y, z). To compute the final output, we need to reduce + // the 8 threads over y by summation. +#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) + +#define reduceRow(i, mask) \ + shuffleInc(i, 0, mask); \ + shuffleInc(i, 1, mask); \ + shuffleInc(i, 2, mask); \ + shuffleInc(i, 3, mask); \ + shuffleInc(i, 4, mask); \ + shuffleInc(i, 5, mask); \ + shuffleInc(i, 6, mask); \ + shuffleInc(i, 7, mask); \ + +#define reduceMatrix(mask) \ + reduceRow(0, mask); \ + reduceRow(1, mask); \ + reduceRow(2, mask); \ + reduceRow(3, mask); \ + reduceRow(4, mask); \ + reduceRow(5, mask); \ + reduceRow(6, mask); \ + reduceRow(7, mask); \ + + // actually perform the reduction, now each thread of index (_, y, z) + // contains the correct values in its registers that belong in the output + // block + reduceMatrix(1); + reduceMatrix(2); + reduceMatrix(4); + +#undef shuffleInc +#undef reduceRow +#undef reduceMatrix + + // now we need to copy the 64 values into main memory. We can't split work + // among threads because all variables are in registers. 
There's 2 ways + // to do this: + // (1) have 1 thread do 64 writes from registers into global memory + // (2) have 1 thread do 64 writes into shared memory, and then 8 threads + // each do 8 writes into global memory. We can just overwrite the shared + // memory from the problem we just solved. + // (2) is slightly faster than (1) due to less branching and more ILP + + // TODO: won't yield much gain, but could just use currently unused shared mem + // and then we won't have to sync + // wait for shared mem to be out of use + __syncthreads(); + +#define writeResultShmem(i, j) \ + lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \ + +#define writeRow(i) \ + writeResultShmem(i, 0); \ + writeResultShmem(i, 1); \ + writeResultShmem(i, 2); \ + writeResultShmem(i, 3); \ + writeResultShmem(i, 4); \ + writeResultShmem(i, 5); \ + writeResultShmem(i, 6); \ + writeResultShmem(i, 7); \ + + if (threadIdx.x == 0) { + writeRow(0); + writeRow(1); + writeRow(2); + writeRow(3); + writeRow(4); + writeRow(5); + writeRow(6); + writeRow(7); + } +#undef writeResultShmem +#undef writeRow + + const int max_i_write = (min)((int)((m_size - base_m - threadIdx.y + 7) / 8), 8); + const int max_j_write = (min)((int)((n_size - base_n - threadIdx.z + 7) / 8), 8); + + if (threadIdx.x < max_i_write) { + if (max_j_write == 8) { + Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0]; + Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1]; + Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2]; + Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3]; + Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4]; + Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5]; + Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6]; + Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7]; + + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7; + } else { +#pragma unroll 7 + for (int j = 0; j < max_j_write; j++) { + Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j]; + output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val; + } + } + } +#undef res + } + + + template +__global__ void +__launch_bounds__(512) + EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, + const Index m_size, const Index n_size, const Index k_size) { + __shared__ volatile Scalar lhs_shmem[72 * 64]; + __shared__ volatile Scalar rhs_shmem[72 * 64]; + + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + if (base_m + 63 < m_size && base_n + 63 < n_size) { + 
EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } else { + EigenContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } + } + + + + template +__device__ EIGEN_STRONG_INLINE void + EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, float4* lhs_shmem4, float2* rhs_shmem2, + const Index m_size, const Index n_size, const Index k_size) { + typedef float Scalar; + + const Index m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + const Index lane = threadIdx.x + 8 * (threadIdx.y % 4); + + // prefetch registers + float4 lhs_pf0; + float4 lhs_pf1; + + float4 rhs_pf0; + float4 rhs_pf1; + + // shared memory is formatted + // (contract idx in block, nocontract idx in block, block idx) + // where block idx is column major. This transposition limits the number of + // bank conflicts when reading the LHS. The core idea is that since the contracting + // index is shared by both sides, then the contracting index should be in threadIdx.x. + + // all of these indices assume float4 loading + // this thread loads the float4 starting at this index, and then also loads + // another float4 starting 32 columns to to the right + const Index horiz_block_idx = threadIdx.z / 2; + const Index vert_block_idx = threadIdx.x / 2 + 4 * (threadIdx.y % 2); + const Index horiz_idx_in_block = threadIdx.y / 2 + 4 * (threadIdx.z % 2); + const Index vert_idx_in_block = threadIdx.x % 2; + + // there's padding in both the LHS and RHS shared memory layouts. This padding + // allows for 0 bank conflicts on all shmem stores and loads. + // LHS padding: 1 float4 on each 8x8 block of floats + // RHS padding: 1 float2 on each block, and 12 additional float2s between vertical blocks + // 3 and 4 + + // storage indices + // lhs index with respect to float4s + const Index lhs_store_idx_base = + 136 * horiz_block_idx + + 17 * vert_block_idx + + 8 * vert_idx_in_block + + horiz_idx_in_block; + + // rhs index with respect to floats + const Index rhs_store_idx_base = + 552 * horiz_block_idx + + 66 * vert_block_idx + + 32 * (horiz_idx_in_block / 4) + (horiz_idx_in_block % 4) + + 16 * vert_idx_in_block + + ((vert_block_idx < 4) ? 0 : 24); + + const Index lhs_store_idx_0 = lhs_store_idx_base + 544 * 0; + const Index lhs_store_idx_1 = lhs_store_idx_base + 544 * 1; + + const Index rhs_store_idx_0 = (rhs_store_idx_base / 2) + ((lane < 16) ? 0 : 4); + const Index rhs_store_idx_1 = rhs_store_idx_0 + 2; + const Index rhs_store_idx_2 = rhs_store_idx_0 + 1104; + const Index rhs_store_idx_3 = rhs_store_idx_1 + 1104; + + // The below diagrams show which shmem index (with respect to floats) each element + // in an 8x8 input block gets packed into: + // LHS: + // 0 4 8 12 16 20 24 28 + // 1 5 9 13 17 21 25 29 + // 2 6 10 14 18 22 26 30 + // 3 7 11 15 19 23 27 31 + // 32 36 40 44 48 52 56 60 + // ... (pack as 2 rows of float4 indexed row major, each float4 is vertical) + // + // RHS: + // 0 1 2 3 32 33 34 35 + // 4 5 6 7 36 37 38 39 + // ... (pack as 2 cols of float4 indexed col major, each float4 is horizontal) + + // Each thread in a warp loads 2 float4s. This happens in 2 instructions. On each of these + // instruction, the warp loads 2 columns (2 cols * 64 elements / col = 128 elements = 32 threads + // * 4 elements/thread). 
For the LHS, we're able to store the loaded float4 directly into + // shmem (using a 128 bit store instruction). For the RHS, we need to transpose the data. + // This is done with warp shuffles. Furthermore, we only use 64 bit stores for the RHS, because + // 64 bits is only 2 columns (which is all we load in a warp), and the padding for the RHS + // doesn't meet 64 bit alignment requirements (namely, the 4 consecutive floats that we want + // to load on the RHS are 8 byte aligned, not 16 byte aligned, which is required for float4). + + const Index load_idx_vert = 4 * (threadIdx.x + 8 * (threadIdx.y % 2)); + const Index load_idx_horiz = (threadIdx.y / 2) + 4 * threadIdx.z; + + const Index lhs_vert = base_m + load_idx_vert; + const Index rhs_horiz_0 = base_n + load_idx_horiz; + const Index rhs_horiz_1 = base_n + load_idx_horiz + 32; + +#define prefetchIntoRegisters(base_k) \ + { \ + lhs_pf0 = internal::pset1(0); \ + lhs_pf1 = internal::pset1(0); \ + \ + rhs_pf0 = internal::pset1(0); \ + rhs_pf1 = internal::pset1(0); \ + \ + const Index lhs_horiz_0 = base_k + load_idx_horiz; \ + const Index lhs_horiz_1 = base_k + load_idx_horiz + 32; \ + if (!needs_edge_check || lhs_vert + 3 < m_size) { \ + if (lhs_horiz_1 < k_size) { \ + lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ + lhs_pf1 = lhs.loadPacket(lhs_vert, lhs_horiz_1); \ + } else if (lhs_horiz_0 < k_size) { \ + lhs_pf0 = lhs.loadPacket(lhs_vert, lhs_horiz_0); \ + } \ + } else if (lhs_vert + 2 < m_size) { \ + if (lhs_horiz_1 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ + lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ + \ + lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ + lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ + lhs_pf1.z = lhs(lhs_vert + 2, lhs_horiz_1); \ + } else if (lhs_horiz_0 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ + lhs_pf0.z = lhs(lhs_vert + 2, lhs_horiz_0); \ + } \ + } else if (lhs_vert + 1 < m_size) { \ + if (lhs_horiz_1 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ + \ + lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ + lhs_pf1.y = lhs(lhs_vert + 1, lhs_horiz_1); \ + } else if (lhs_horiz_0 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + lhs_pf0.y = lhs(lhs_vert + 1, lhs_horiz_0); \ + } \ + } else if (lhs_vert < m_size) { \ + if (lhs_horiz_1 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + lhs_pf1.x = lhs(lhs_vert + 0, lhs_horiz_1); \ + } else if (lhs_horiz_0 < k_size) { \ + lhs_pf0.x = lhs(lhs_vert + 0, lhs_horiz_0); \ + } \ +} \ + \ + const Index rhs_vert = base_k + load_idx_vert; \ + if (rhs_vert + 3 < k_size) { \ + if (!needs_edge_check || rhs_horiz_1 < n_size) { \ + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ + rhs_pf1 = rhs.loadPacket(rhs_vert, rhs_horiz_1); \ + } else if (rhs_horiz_0 < n_size) { \ + rhs_pf0 = rhs.loadPacket(rhs_vert, rhs_horiz_0); \ + } \ + } else if (rhs_vert + 2 < k_size) { \ + if (!needs_edge_check || rhs_horiz_1 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ + rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz_0); \ + \ + rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ + rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz_1); \ + } else if (rhs_horiz_0 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ + rhs_pf0.z = 
rhs(rhs_vert + 2, rhs_horiz_0); \ + } \ + } else if (rhs_vert + 1 < k_size) { \ + if (!needs_edge_check || rhs_horiz_1 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ + \ + rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ + rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz_1); \ + } else if (rhs_horiz_0 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz_0); \ + } \ + } else if (rhs_vert < k_size) { \ + if (!needs_edge_check || rhs_horiz_1 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + rhs_pf1.x = rhs(rhs_vert + 0, rhs_horiz_1); \ + } else if (rhs_horiz_0 < n_size) { \ + rhs_pf0.x = rhs(rhs_vert + 0, rhs_horiz_0); \ + } \ +} \ + \ + float swap_val0 = (lane < 16) ? rhs_pf0.z : rhs_pf0.x; \ + float swap_val1 = (lane < 16) ? rhs_pf0.w : rhs_pf0.y; \ + float swap_val2 = (lane < 16) ? rhs_pf1.z : rhs_pf1.x; \ + float swap_val3 = (lane < 16) ? rhs_pf1.w : rhs_pf1.y; \ + \ + swap_val0 = __shfl_xor(swap_val0, 16); \ + swap_val1 = __shfl_xor(swap_val1, 16); \ + swap_val2 = __shfl_xor(swap_val2, 16); \ + swap_val3 = __shfl_xor(swap_val3, 16); \ + \ + if (lane < 16) { \ + rhs_pf0.z = swap_val0; \ + rhs_pf0.w = swap_val1; \ + rhs_pf1.z = swap_val2; \ + rhs_pf1.w = swap_val3; \ + } else { \ + rhs_pf0.x = swap_val0; \ + rhs_pf0.y = swap_val1; \ + rhs_pf1.x = swap_val2; \ + rhs_pf1.y = swap_val3; \ + } \ +} \ + + +#define writeRegToShmem(_) \ + lhs_shmem4[lhs_store_idx_0] = lhs_pf0; \ + \ + rhs_shmem2[rhs_store_idx_0] = make_float2(rhs_pf0.x, rhs_pf0.z); \ + rhs_shmem2[rhs_store_idx_1] = make_float2(rhs_pf0.y, rhs_pf0.w); \ + \ + lhs_shmem4[lhs_store_idx_1] = lhs_pf1; \ + \ + rhs_shmem2[rhs_store_idx_2] = make_float2(rhs_pf1.x, rhs_pf1.z); \ + rhs_shmem2[rhs_store_idx_3] = make_float2(rhs_pf1.y, rhs_pf1.w); \ + + // declare and initialize result array +#define res(i, j) _res_##i##j +#define initResultRow(i) \ + Scalar res(i, 0) = Scalar(0); \ + Scalar res(i, 1) = Scalar(0); \ + Scalar res(i, 2) = Scalar(0); \ + Scalar res(i, 3) = Scalar(0); \ + Scalar res(i, 4) = Scalar(0); \ + Scalar res(i, 5) = Scalar(0); \ + Scalar res(i, 6) = Scalar(0); \ + Scalar res(i, 7) = Scalar(0); \ + + initResultRow(0); + initResultRow(1); + initResultRow(2); + initResultRow(3); + initResultRow(4); + initResultRow(5); + initResultRow(6); + initResultRow(7); +#undef initResultRow + + for (Index base_k = 0; base_k < k_size; base_k += 64) { + // wait for previous iteration to finish with shmem. Despite common sense, + // the code is a bit faster with this here then at bottom of loop + __syncthreads(); + + prefetchIntoRegisters(base_k); + writeRegToShmem(); + +#undef prefetchIntoRegisters +#undef writeRegoToShmem + + // wait for shared mem packing to be done before starting computation + __syncthreads(); + + // compute 8x8 matrix product by outer product. This involves packing one column + // of LHS and one row of RHS into registers (takes 16 registers). 
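+
+    // The prefetchIntoRegisters macro above trades half of each RHS float4
+    // with the lane 16 away, which puts the data each half-warp needs for its
+    // transposed float2 stores into the right registers. A standalone sketch
+    // of that exchange, kept under #if 0 and never compiled; the helper name
+    // is illustrative only.
+#if 0
+    __device__ void swap_halves_with_partner(float4& v, int lane) {
+      float a = (lane < 16) ? v.z : v.x;       // value to send to lane ^ 16
+      float b = (lane < 16) ? v.w : v.y;
+      a = __shfl_xor(a, 16);                   // exchange with the partner lane
+      b = __shfl_xor(b, 16);
+      if (lane < 16) { v.z = a; v.w = b; }     // lower half receives upper data
+      else           { v.x = a; v.y = b; }     // upper half receives lower data
+    }
+#endif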
+ + float4 _lcol0; + float4 _lcol1; + float2 _rrow0; + float2 _rrow1; + float2 _rrow2; + float2 _rrow3; + +#define lcol0 _lcol0.x +#define lcol1 _lcol0.y +#define lcol2 _lcol0.z +#define lcol3 _lcol0.w +#define lcol4 _lcol1.x +#define lcol5 _lcol1.y +#define lcol6 _lcol1.z +#define lcol7 _lcol1.w +#define rrow0 _rrow0.x +#define rrow1 _rrow0.y +#define rrow2 _rrow1.x +#define rrow3 _rrow1.y +#define rrow4 _rrow2.x +#define rrow5 _rrow2.y +#define rrow6 _rrow3.x +#define rrow7 _rrow3.y + + // Now x corresponds to k, y to m, and z to n + const float4* lhs_block = &lhs_shmem4[threadIdx.x + 8 * (threadIdx.y % 2) + 17 * (threadIdx.y / 2)]; + const float2* rhs_block = &rhs_shmem2[2 * threadIdx.x + 16 * (threadIdx.z % 2) + 276 * (threadIdx.z / 2)]; + +#define lhs_element(i, k) lhs_block[68 * i + 136 * k] +#define rhs_element(k, j) rhs_block[33 * k + 1104 * j + ((k < 4) ? 0 : 12)] + +#define loadData(i) \ + _lcol0 = lhs_element(0, i); \ + _rrow0 = rhs_element(i, 0); \ + _rrow1 = *(&(rhs_element(i, 0)) + 1); \ + _lcol1 = lhs_element(1, i); \ + _rrow2 = rhs_element(i, 1); \ + _rrow3 = *(&(rhs_element(i, 1)) + 1); \ + +#define computeCol(j) \ + res(0, j) += lcol0 * rrow##j; \ + res(1, j) += lcol1 * rrow##j; \ + res(2, j) += lcol2 * rrow##j; \ + res(3, j) += lcol3 * rrow##j; \ + res(4, j) += lcol4 * rrow##j; \ + res(5, j) += lcol5 * rrow##j; \ + res(6, j) += lcol6 * rrow##j; \ + res(7, j) += lcol7 * rrow##j; \ + +#define computePass(i) \ + loadData(i); \ + \ + computeCol(0); \ + computeCol(1); \ + computeCol(2); \ + computeCol(3); \ + computeCol(4); \ + computeCol(5); \ + computeCol(6); \ + computeCol(7); \ + + computePass(0); + computePass(1); + computePass(2); + computePass(3); + computePass(4); + computePass(5); + computePass(6); + computePass(7); + +#undef lcol0 +#undef lcol1 +#undef lcol2 +#undef lcol3 +#undef lcol4 +#undef lcol5 +#undef lcol6 +#undef lcol7 +#undef rrow0 +#undef rrow1 +#undef rrow2 +#undef rrow3 +#undef rrow4 +#undef rrow5 +#undef rrow6 +#undef rrow7 + +#undef computePass +#undef computeCol +#undef loadData +#undef lhs_element +#undef rhs_element + + } // end loop over k + + // we've now iterated over all of the large (ie width 64) k blocks and + // accumulated results in registers. At this point thread (x, y, z) contains + // the sum across all big k blocks of the product of little k block of index (x, y) + // with block of index (y, z). To compute the final output, we need to reduce + // the 8 threads over y by summation. +#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) + +#define reduceRow(i, mask) \ + shuffleInc(i, 0, mask); \ + shuffleInc(i, 1, mask); \ + shuffleInc(i, 2, mask); \ + shuffleInc(i, 3, mask); \ + shuffleInc(i, 4, mask); \ + shuffleInc(i, 5, mask); \ + shuffleInc(i, 6, mask); \ + shuffleInc(i, 7, mask); \ + +#define reduceMatrix(mask) \ + reduceRow(0, mask); \ + reduceRow(1, mask); \ + reduceRow(2, mask); \ + reduceRow(3, mask); \ + reduceRow(4, mask); \ + reduceRow(5, mask); \ + reduceRow(6, mask); \ + reduceRow(7, mask); \ + + // actually perform the reduction, now each thread of index (_, y, z) + // contains the correct values in its registers that belong in the output + // block + reduceMatrix(1); + reduceMatrix(2); + reduceMatrix(4); + +#undef shuffleInc +#undef reduceRow +#undef reduceMatrix + + // now we need to copy the 64 values into main memory. We can't split work + // among threads because all variables are in registers. 
There's 2 ways + // to do this: + // (1) have 1 thread do 64 writes from registers into global memory + // (2) have 1 thread do 64 writes into shared memory, and then 8 threads + // each do 8 writes into global memory. We can just overwrite the shared + // memory from the problem we just solved. + // (3) Copies the values into new registers using conditional logic. + +#define makeAssignments(i) \ + val0 = res(i, 0); \ + val1 = res(i, 1); \ + val2 = res(i, 2); \ + val3 = res(i, 3); \ + val4 = res(i, 4); \ + val5 = res(i, 5); \ + val6 = res(i, 6); \ + val7 = res(i, 7); \ + + Scalar val0; + Scalar val1; + Scalar val2; + Scalar val3; + Scalar val4; + Scalar val5; + Scalar val6; + Scalar val7; + + switch (threadIdx.x) { + case 0: + makeAssignments(0); + break; + case 1: + makeAssignments(1); + break; + case 2: + makeAssignments(2); + break; + case 3: + makeAssignments(3); + break; + case 4: + makeAssignments(4); + break; + case 5: + makeAssignments(5); + break; + case 6: + makeAssignments(6); + break; + case 7: + makeAssignments(7); + break; + } + +#undef res + + const Index vert_base = base_m + 4 * threadIdx.y + (threadIdx.x % 4) + 32 * (threadIdx.x / 4); + const Index horiz_base = base_n + 4 * threadIdx.z; + + if (!needs_edge_check || vert_base < m_size) { + if (!needs_edge_check || horiz_base + 35 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + output(vert_base, horiz_base + 3) = val3; + output(vert_base, horiz_base + 32) = val4; + output(vert_base, horiz_base + 33) = val5; + output(vert_base, horiz_base + 34) = val6; + output(vert_base, horiz_base + 35) = val7; + } else if (horiz_base + 34 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + output(vert_base, horiz_base + 3) = val3; + output(vert_base, horiz_base + 32) = val4; + output(vert_base, horiz_base + 33) = val5; + output(vert_base, horiz_base + 34) = val6; + } else if (horiz_base + 33 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + output(vert_base, horiz_base + 3) = val3; + output(vert_base, horiz_base + 32) = val4; + output(vert_base, horiz_base + 33) = val5; + } else if (horiz_base + 32 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + output(vert_base, horiz_base + 3) = val3; + output(vert_base, horiz_base + 32) = val4; + } else if (horiz_base + 3 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + output(vert_base, horiz_base + 3) = val3; + } else if (horiz_base + 2 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + output(vert_base, horiz_base + 2) = val2; + } else if (horiz_base + 1 < n_size) { + output(vert_base, horiz_base + 0) = val0; + output(vert_base, horiz_base + 1) = val1; + } else if (horiz_base < n_size) { + output(vert_base, horiz_base + 0) = val0; + } + } + } + + + template +__global__ void + __launch_bounds__(512) + EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, + const OutputMapper output, + const Index m_size, const Index n_size, const Index k_size) { + __shared__ float4 lhs_shmem[(68 * 64) / 4]; + __shared__ float2 rhs_shmem[((66 * 8 + 24) * 8) / 2]; + + const Index 
m_block_idx = blockIdx.x; + const Index n_block_idx = blockIdx.y; + + const Index base_m = 64 * m_block_idx; + const Index base_n = 64 * n_block_idx; + + if (base_m + 63 < m_size && base_n + 63 < n_size) { + EigenFloatContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } else { + EigenFloatContractionKernelInternal(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size); + } + } + + + template + struct TensorEvaluator, GpuDevice> : + public TensorContractionEvaluatorBase, GpuDevice> > { + + typedef GpuDevice Device; + + typedef TensorEvaluator, Device> Self; + typedef TensorContractionEvaluatorBase Base; + + typedef TensorContractionOp XprType; + typedef typename internal::remove_const::type Scalar; + typedef typename XprType::Packet Packet; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + typedef array::Dimensions::count> left_dim_mapper_t; + typedef array::Dimensions::count> right_dim_mapper_t; + + typedef array::value> contract_t; + typedef array::Dimensions::count - internal::array_size::value> left_nocontract_t; + typedef array::Dimensions::count - internal::array_size::value> right_nocontract_t; + + static const int NumDims = max_n_1::Dimensions::count + TensorEvaluator::Dimensions::count - 2 * internal::array_size::value>::size; + + typedef DSizes Dimensions; + + // typedefs needed in evalTo + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; + + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; + + typedef typename LeftEvaluator::Dimensions LeftDimensions; + typedef typename RightEvaluator::Dimensions RightDimensions; + + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : + Base(op, device) {} + + // We need to redefine this method to make nvcc happy + EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { + this->m_leftImpl.evalSubExprsIfNeeded(NULL); + this->m_rightImpl.evalSubExprsIfNeeded(NULL); + if (data) { + evalTo(data); + return false; + } else { + this->m_result = static_cast(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar))); + evalTo(this->m_result); + return true; + } + } + + void evalTo(Scalar* buffer) const { + if (this->m_lhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + else { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + } + else { + if (this->m_rhs_inner_dim_contiguous) { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + else { + if (this->m_rhs_inner_dim_reordered) { + evalTyped(buffer); + } + else { + evalTyped(buffer); + } + } + } + } + + template + void evalTyped(Scalar* buffer) const { + // columns in left side, rows in right side + const Index k = this->m_k_size; + + // rows in left side + const Index m = this->m_i_size; + + // columns in right side + const Index n = this->m_j_size; + + // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) + this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); + + typedef internal::TensorContractionInputMapper LhsMapper; + + typedef internal::TensorContractionInputMapper RhsMapper; + + typedef internal::blas_data_mapper OutputMapper; + 
+
+    // initialize data mappers
+    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
+                  this->m_left_contracting_strides, this->m_k_strides);
+
+    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
+                  this->m_right_contracting_strides, this->m_k_strides);
+
+    OutputMapper output(buffer, m);
+
+    const Index m_blocks = (m + 63) / 64;
+    const Index n_blocks = (n + 63) / 64;
+    const dim3 num_blocks(m_blocks, n_blocks, 1);
+    const dim3 block_size(8, 8, 8);
+
+    cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
+    if (internal::is_same<LhsScalar, float>::value &&
+        internal::is_same<RhsScalar, float>::value) {
+      EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>
+          <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k);
+    } else {
+      EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
+          <<<num_blocks, block_size, 0, this->m_device.stream()>>>(lhs, rhs, output, m, n, k);
+    }
+
+    assert(cudaGetLastError() == cudaSuccess);
+  }
+};
+
+} // end namespace Eigen
+
+#endif // EIGEN_USE_GPU and __CUDACC__
+
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
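
Usage sketch (not part of the patch): with the Tensor module compiled by nvcc and EIGEN_USE_GPU defined, a contraction assigned through a GpuDevice is routed to the kernels above; float operands take the EigenFloatContractionKernel path, everything else the generic one. The snippet below only illustrates the intended call pattern. The GpuDevice constructor and Eigen::IndexPair type are assumptions about the surrounding Tensor API of this era rather than something the patch defines, and d_a/d_b/d_c are device buffers allocated by the caller.

    #define EIGEN_USE_GPU
    #include <cuda_runtime.h>
    #include <unsupported/Eigen/CXX11/Tensor>

    // Contract an (m x k) LHS with a (k x n) RHS over their shared dimension.
    void gpu_contract(float* d_a, float* d_b, float* d_c,
                      int m, int k, int n, cudaStream_t stream) {
      Eigen::GpuDevice dev(&stream);  // assumed: GpuDevice wraps a cudaStream_t*
      Eigen::TensorMap<Eigen::Tensor<float, 2> > A(d_a, m, k);
      Eigen::TensorMap<Eigen::Tensor<float, 2> > B(d_b, k, n);
      Eigen::TensorMap<Eigen::Tensor<float, 2> > C(d_c, m, n);
      Eigen::array<Eigen::IndexPair<int>, 1> dims;
      dims[0] = Eigen::IndexPair<int>(1, 0);  // pair A's dim 1 with B's dim 0
      C.device(dev) = A.contract(B, dims);    // launches the contraction kernel
    }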