From 1bb6fa99a31d2dcf5431087d3f238e2dcca03084 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 20 Jun 2018 16:44:58 -0400 Subject: merging the CUDA and HIP implementation for the Tensor directory and the unit tests --- .../Eigen/CXX11/src/Tensor/TensorContractionGpu.h | 189 +++++++++++++++++++-- 1 file changed, 175 insertions(+), 14 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h index 903bc51cc..238754424 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h @@ -9,10 +9,10 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H -#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H -#if defined(EIGEN_USE_GPU) && defined(EIGEN_CUDACC) +#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC) namespace Eigen { @@ -388,7 +388,7 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, // the sum across all big k blocks of the product of little k block of index (x, y) // with block of index (y, z). To compute the final output, we need to reduce // the 8 threads over y by summation. -#if defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000 +#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000) #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask) #else #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask) @@ -503,7 +503,11 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, template __global__ void +#if defined(EIGEN_HIPCC) +__launch_bounds__(512, 1) +#else __launch_bounds__(512) +#endif EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size, const Index n_size, const Index k_size) { @@ -542,7 +546,45 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh results[i].x = results[i].y = results[i].z = results[i].w = 0; } +#if defined(EIGEN_HIPCC) +#define prefetch_lhs(reg, row, col) \ + if (!CHECK_LHS_BOUNDARY) { \ + if (col < k_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + reg.z =lhs(row + 2, col); \ + reg.w =lhs(row + 3, col); \ + } \ + } else { \ + if (col < k_size) { \ + if (row + 3 < m_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + reg.z =lhs(row + 2, col); \ + reg.w =lhs(row + 3, col); \ + } else if (row + 2 < m_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + reg.z =lhs(row + 2, col); \ + } else if (row + 1 < m_size) { \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + } else if (row < m_size) { \ + reg.x =lhs(row + 0, col); \ + } \ + } \ + } \ + +#define prefetch_rhs_hipcc(reg, row, col) \ + reg.x =rhs(row + 0, col); \ + reg.y =rhs(row + 1, col); \ + reg.z =rhs(row + 2, col); \ + reg.w =rhs(row + 3, col); \ + + +#else + #define prefetch_lhs(reg, row, col) \ if (!CHECK_LHS_BOUNDARY) { \ if (col < k_size) { \ @@ -563,14 +605,21 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh reg.x =lhs(row + 0, col); \ } \ } \ - } \ + } \ +#endif Index lhs_vert = base_m+threadIdx.x*4; for (Index k = 0; k < k_size; k += 16) { + +#if defined(EIGEN_HIPCC) + lhs_pf0 = make_float4(0, 0, 0, 0); + rhs_pf0 = make_float4(0, 0, 0, 0); +#else lhs_pf0 = internal::pset1(0); rhs_pf0 = internal::pset1(0); +#endif Index lhs_horiz = threadIdx.y+k; prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz) @@ -581,7 +630,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh if (!CHECK_RHS_BOUNDARY) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY +#if defined(EIGEN_HIPCC) + prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) +#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); +#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -596,7 +649,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh } else { if (rhs_horiz0 < n_size) { if ((rhs_vert + 3) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) +#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); +#endif } else if ((rhs_vert + 2) < k_size) { rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); @@ -618,7 +675,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh x1 = rhs_pf0.x; x2 = rhs_pf0.z; } - #if defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000 + #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000) x1 = __shfl_xor(x1, 4); x2 = __shfl_xor(x2, 4); #else @@ -695,7 +752,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh #undef prefetch_lhs #undef add_vals - + +#if defined(EIGEN_HIPCC) +#undef prefetch_rhs_hipcc +#endif + Index horiz_base = threadIdx.y*4+base_n; if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { for (int i = 0; i < 4; i++) { @@ -784,9 +845,33 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, results[i].x = results[i].y = results[i].z = results[i].w = 0; } +#if defined(EIGEN_HIPCC) + +#define prefetch_lhs_hipcc(reg, row, col) \ + reg.x =lhs(row + 0, col); \ + reg.y =lhs(row + 1, col); \ + reg.z =lhs(row + 2, col); \ + reg.w =lhs(row + 3, col); + +#define prefetch_rhs_hipcc(reg, row, col) \ + reg.x =rhs(row + 0, col); \ + reg.y =rhs(row + 1, col); \ + reg.z =rhs(row + 2, col); \ + reg.w =rhs(row + 3, col); + +#endif Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32; for (Index k = 0; k < k_size; k += 32) { +#if defined(EIGEN_HIPCC) + lhs_pf0 = make_float4(0, 0, 0, 0); + lhs_pf1 = make_float4(0, 0, 0, 0); + lhs_pf2 = make_float4(0, 0, 0, 0); + lhs_pf3 = make_float4(0, 0, 0, 0); + + rhs_pf0 = make_float4(0, 0, 0, 0); + rhs_pf1 = make_float4(0, 0, 0, 0); +#else lhs_pf0 = internal::pset1(0); lhs_pf1 = internal::pset1(0); lhs_pf2 = internal::pset1(0); @@ -794,40 +879,85 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, rhs_pf0 = internal::pset1(0); rhs_pf1 = internal::pset1(0); +#endif if (!CHECK_LHS_BOUNDARY) { if ((threadIdx.y/4+k+24) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) + prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) + prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); lhs_pf3 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+24)); +#endif } else if ((threadIdx.y/4+k+16) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) + prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); +#endif } else if ((threadIdx.y/4+k+8) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); +#endif } else if ((threadIdx.y/4+k) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); +#endif } } else { // just CHECK_LHS_BOUNDARY if (lhs_vert + 3 < m_size) { if ((threadIdx.y/4+k+24) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) + prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) + prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); lhs_pf3 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+24)); +#endif } else if ((threadIdx.y/4+k+16) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) + prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); +#endif } else if ((threadIdx.y/4+k+8) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) + prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); +#endif } else if ((threadIdx.y/4+k) < k_size) { +#if defined(EIGEN_HIPCC) + prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) +#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); +#endif } } else if (lhs_vert + 2 < m_size) { if ((threadIdx.y/4+k+24) < k_size) { @@ -916,8 +1046,13 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, if (!CHECK_RHS_BOUNDARY) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY +#if defined(EIGEN_HIPCC) + prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) + prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1) +#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); rhs_pf1 = rhs.template loadPacket(rhs_vert, rhs_horiz1); +#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -939,8 +1074,13 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, if (rhs_horiz1 < n_size) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY +#if defined(EIGEN_HIPCC) + prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) + prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1) +#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); rhs_pf1 = rhs.template loadPacket(rhs_vert, rhs_horiz1); +#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -961,7 +1101,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, } else if (rhs_horiz0 < n_size) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY +#if defined(EIGEN_HIPCC) + prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) +#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); +#endif } else if ((rhs_vert + 2) < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -1069,7 +1213,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, __syncthreads(); } // end loop over k - +#if defined(EIGEN_HIPCC) +#undef prefetch_lhs_hipcc +#undef prefetch_rhs_hipcc +#endif + __syncthreads(); Index horiz_base = (threadIdx.y/4)*8+base_n; if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { @@ -1134,7 +1282,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, template __global__ void +#if defined(EIGEN_HIPCC) +__launch_bounds__(256, 1) +#else __launch_bounds__(256) +#endif EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size, const Index n_size, const Index k_size) { @@ -1177,7 +1329,11 @@ EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, template __global__ void +#if defined(EIGEN_HIPCC) +__launch_bounds__(256, 1) +#else __launch_bounds__(256) +#endif EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size, const Index n_size, const Index k_size) { @@ -1323,7 +1479,7 @@ struct TensorEvaluator), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); + LAUNCH_GPU_KERNEL((EigenContractionKernel), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); } }; @@ -1334,13 +1490,13 @@ struct TensorEvaluator), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); + LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); } else { const Index m_blocks = (m + 127) / 128; const Index n_blocks = (n + 63) / 64; const dim3 num_blocks(m_blocks, n_blocks, 1); const dim3 block_size(8, 32, 1); - LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); + LAUNCH_GPU_KERNEL((EigenFloatContractionKernel), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k); } } }; @@ -1384,12 +1540,17 @@ struct TensorEvaluator::Run(lhs, rhs, output, m, n, k, this->m_device); } }; } // end namespace Eigen -#endif // EIGEN_USE_GPU and EIGEN_CUDACC -#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H +#endif // EIGEN_USE_GPU and EIGEN_GPUCC +#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H -- cgit v1.2.3