From 876f392c396318f33454168db36ed54308e54e0d Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Wed, 11 Jul 2018 10:39:54 -0400 Subject: Updates corresponding to the latest round of PR feedback The major changes are 1. Moving CUDA/PacketMath.h to GPU/PacketMath.h 2. Moving CUDA/MathFunctions.h to GPU/MathFunctions.h 3. Moving CUDA/CudaSpecialFunctions.h to GPU/GpuSpecialFunctions.h The above three changes effectively enable the Eigen "Packet" layer for the HIP platform 4. Merging the "hip_basic" and "cuda_basic" unit tests into one ("gpu_basic") 5. Updating the "EIGEN_DEVICE_FUNC" marking in some places The change has been tested on the HIP and CUDA platforms. --- .../Eigen/CXX11/src/Tensor/TensorContractionGpu.h | 147 --------------------- 1 file changed, 147 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h index 238754424..a4f92ee44 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h @@ -546,45 +546,6 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh results[i].x = results[i].y = results[i].z = results[i].w = 0; } -#if defined(EIGEN_HIPCC) - -#define prefetch_lhs(reg, row, col) \ - if (!CHECK_LHS_BOUNDARY) { \ - if (col < k_size) { \ - reg.x =lhs(row + 0, col); \ - reg.y =lhs(row + 1, col); \ - reg.z =lhs(row + 2, col); \ - reg.w =lhs(row + 3, col); \ - } \ - } else { \ - if (col < k_size) { \ - if (row + 3 < m_size) { \ - reg.x =lhs(row + 0, col); \ - reg.y =lhs(row + 1, col); \ - reg.z =lhs(row + 2, col); \ - reg.w =lhs(row + 3, col); \ - } else if (row + 2 < m_size) { \ - reg.x =lhs(row + 0, col); \ - reg.y =lhs(row + 1, col); \ - reg.z =lhs(row + 2, col); \ - } else if (row + 1 < m_size) { \ - reg.x =lhs(row + 0, col); \ - reg.y =lhs(row + 1, col); \ - } else if (row < m_size) { \ 
- reg.x =lhs(row + 0, col); \ - } \ - } \ - } \ - -#define prefetch_rhs_hipcc(reg, row, col) \ - reg.x =rhs(row + 0, col); \ - reg.y =rhs(row + 1, col); \ - reg.z =rhs(row + 2, col); \ - reg.w =rhs(row + 3, col); \ - - -#else - #define prefetch_lhs(reg, row, col) \ if (!CHECK_LHS_BOUNDARY) { \ if (col < k_size) { \ @@ -607,19 +568,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh } \ } \ -#endif - Index lhs_vert = base_m+threadIdx.x*4; for (Index k = 0; k < k_size; k += 16) { -#if defined(EIGEN_HIPCC) - lhs_pf0 = make_float4(0, 0, 0, 0); - rhs_pf0 = make_float4(0, 0, 0, 0); -#else lhs_pf0 = internal::pset1(0); rhs_pf0 = internal::pset1(0); -#endif Index lhs_horiz = threadIdx.y+k; prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz) @@ -630,11 +584,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh if (!CHECK_RHS_BOUNDARY) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY -#if defined(EIGEN_HIPCC) - prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) -#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); -#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -649,11 +599,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh } else { if (rhs_horiz0 < n_size) { if ((rhs_vert + 3) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) -#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); -#endif } else if ((rhs_vert + 2) < k_size) { rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); @@ -753,10 +699,6 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh #undef prefetch_lhs #undef add_vals -#if defined(EIGEN_HIPCC) -#undef prefetch_rhs_hipcc -#endif - Index horiz_base = threadIdx.y*4+base_n; if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { for (int i = 0; i < 4; i++) { @@ -845,33 +787,8 @@ 
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, results[i].x = results[i].y = results[i].z = results[i].w = 0; } -#if defined(EIGEN_HIPCC) - -#define prefetch_lhs_hipcc(reg, row, col) \ - reg.x =lhs(row + 0, col); \ - reg.y =lhs(row + 1, col); \ - reg.z =lhs(row + 2, col); \ - reg.w =lhs(row + 3, col); - -#define prefetch_rhs_hipcc(reg, row, col) \ - reg.x =rhs(row + 0, col); \ - reg.y =rhs(row + 1, col); \ - reg.z =rhs(row + 2, col); \ - reg.w =rhs(row + 3, col); - -#endif - Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32; for (Index k = 0; k < k_size; k += 32) { -#if defined(EIGEN_HIPCC) - lhs_pf0 = make_float4(0, 0, 0, 0); - lhs_pf1 = make_float4(0, 0, 0, 0); - lhs_pf2 = make_float4(0, 0, 0, 0); - lhs_pf3 = make_float4(0, 0, 0, 0); - - rhs_pf0 = make_float4(0, 0, 0, 0); - rhs_pf1 = make_float4(0, 0, 0, 0); -#else lhs_pf0 = internal::pset1(0); lhs_pf1 = internal::pset1(0); lhs_pf2 = internal::pset1(0); @@ -879,85 +796,40 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, rhs_pf0 = internal::pset1(0); rhs_pf1 = internal::pset1(0); -#endif if (!CHECK_LHS_BOUNDARY) { if ((threadIdx.y/4+k+24) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) - prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) - prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); lhs_pf3 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+24)); -#endif } else if ((threadIdx.y/4+k+16) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) - prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) -#else 
lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); -#endif } else if ((threadIdx.y/4+k+8) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); -#endif } else if ((threadIdx.y/4+k) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); -#endif } } else { // just CHECK_LHS_BOUNDARY if (lhs_vert + 3 < m_size) { if ((threadIdx.y/4+k+24) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) - prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) - prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); lhs_pf3 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+24)); -#endif } else if ((threadIdx.y/4+k+16) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) - prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); lhs_pf2 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+16)); -#endif } else if ((threadIdx.y/4+k+8) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) - 
prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); lhs_pf1 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k+8)); -#endif } else if ((threadIdx.y/4+k) < k_size) { -#if defined(EIGEN_HIPCC) - prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k)) -#else lhs_pf0 =lhs.template loadPacket(lhs_vert, (threadIdx.y/4+k)); -#endif } } else if (lhs_vert + 2 < m_size) { if ((threadIdx.y/4+k+24) < k_size) { @@ -1046,13 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, if (!CHECK_RHS_BOUNDARY) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY -#if defined(EIGEN_HIPCC) - prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) - prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1) -#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); rhs_pf1 = rhs.template loadPacket(rhs_vert, rhs_horiz1); -#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -1074,13 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, if (rhs_horiz1 < n_size) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY -#if defined(EIGEN_HIPCC) - prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) - prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1) -#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); rhs_pf1 = rhs.template loadPacket(rhs_vert, rhs_horiz1); -#endif } else if (rhs_vert + 2 < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); @@ -1101,11 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, } else if (rhs_horiz0 < n_size) { if ((rhs_vert + 3) < k_size) { // just CHECK_RHS_BOUNDARY -#if defined(EIGEN_HIPCC) - prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0) -#else rhs_pf0 = rhs.template loadPacket(rhs_vert, rhs_horiz0); -#endif } else if ((rhs_vert + 2) < k_size) { // just CHECK_RHS_BOUNDARY rhs_pf0.x = 
rhs(rhs_vert, rhs_horiz0); @@ -1213,11 +1071,6 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, __syncthreads(); } // end loop over k -#if defined(EIGEN_HIPCC) -#undef prefetch_lhs_hipcc -#undef prefetch_rhs_hipcc -#endif - __syncthreads(); Index horiz_base = (threadIdx.y/4)*8+base_n; if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { -- cgit v1.2.3