path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
author    Deven Desai <deven.desai.amd@gmail.com>  2018-06-20 16:44:58 -0400
committer Deven Desai <deven.desai.amd@gmail.com>  2018-06-20 16:44:58 -0400
commit    1bb6fa99a31d2dcf5431087d3f238e2dcca03084
tree      e62d41b8d6430849aea4bf97785a54488bf542d4 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
parent    cfdabbcc8f708c06da2bfa4e924edc25619f013a
Merge the CUDA and HIP implementations for the Tensor directory and the unit tests
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h | 189
1 file changed, 175 insertions(+), 14 deletions(-)
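
Two patterns recur throughout the patch below: CUDA-specific guard macros are renamed to GPU-generic spellings (EIGEN_CUDACC becomes EIGEN_GPUCC, LAUNCH_CUDA_KERNEL becomes LAUNCH_GPU_KERNEL), and HIP-only branches are added behind EIGEN_HIPCC wherever hipcc needs different intrinsics or APIs. A minimal sketch of the guard layering this relies on (the actual definitions live in Eigen's core macro and GPU support headers, not in this file):

// Sketch only: how the unified guard is derived from the compiler macros.
#if defined(__CUDACC__) && !defined(EIGEN_CUDACC)
#define EIGEN_CUDACC __CUDACC__   // device code compiled with nvcc
#endif
#if defined(__HIPCC__) && !defined(EIGEN_HIPCC)
#define EIGEN_HIPCC __HIPCC__     // device code compiled with hipcc
#endif
#if defined(EIGEN_CUDACC) || defined(EIGEN_HIPCC)
#define EIGEN_GPUCC               // either GPU compiler is active
#endif

A single #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC) guard thus admits both compilers, while #if defined(EIGEN_HIPCC) selects the HIP-specific branches.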
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
index 903bc51cc..238754424 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h
@@ -9,10 +9,10 @@
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
-#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
-#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
+#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
+#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
-#if defined(EIGEN_USE_GPU) && defined(EIGEN_CUDACC)
+#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
namespace Eigen {
@@ -388,7 +388,7 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
// the sum across all big k blocks of the product of little k block of index (x, y)
// with block of index (y, z). To compute the final output, we need to reduce
// the 8 threads over y by summation.
-#if defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000
+#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
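
Note on the shuffle dispatch above: CUDA 9 deprecated the maskless warp shuffles in favor of the *_sync variants, while HIP provides only the maskless form, so the HIP build reuses the pre-9.0 branch. The same dispatch written as an inline helper rather than a macro (hypothetical name, sketch only):

template <typename T>
__device__ inline T shfl_xor_compat(T v, int mask) {
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000)
  return __shfl_xor(v, mask);                   // HIP and pre-9.0 CUDA
#else
  return __shfl_xor_sync(0xFFFFFFFF, v, mask);  // CUDA 9+: explicit lane mask
#endif
}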
@@ -503,7 +503,11 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
template<typename Scalar, typename Index, typename LhsMapper,
typename RhsMapper, typename OutputMapper>
__global__ void
+#if defined(EIGEN_HIPCC)
+__launch_bounds__(512, 1)
+#else
__launch_bounds__(512)
+#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) {
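
Note on __launch_bounds__: under hipcc the second argument means minimum warps per execution unit, not CUDA's minimum blocks per multiprocessor, and the patch passes 1, the least constraining value, on every kernel. The recurring pattern could be factored once, along these lines (hypothetical macro name, sketch only):

#if defined(EIGEN_HIPCC)
#define EIGEN_GPU_LAUNCH_BOUNDS(maxThreads) __launch_bounds__(maxThreads, 1)
#else
#define EIGEN_GPU_LAUNCH_BOUNDS(maxThreads) __launch_bounds__(maxThreads)
#endif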
@@ -542,10 +546,48 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
+#if defined(EIGEN_HIPCC)
#define prefetch_lhs(reg, row, col) \
if (!CHECK_LHS_BOUNDARY) { \
if (col < k_size) { \
+ reg.x =lhs(row + 0, col); \
+ reg.y =lhs(row + 1, col); \
+ reg.z =lhs(row + 2, col); \
+ reg.w =lhs(row + 3, col); \
+ } \
+ } else { \
+ if (col < k_size) { \
+ if (row + 3 < m_size) { \
+ reg.x =lhs(row + 0, col); \
+ reg.y =lhs(row + 1, col); \
+ reg.z =lhs(row + 2, col); \
+ reg.w =lhs(row + 3, col); \
+ } else if (row + 2 < m_size) { \
+ reg.x =lhs(row + 0, col); \
+ reg.y =lhs(row + 1, col); \
+ reg.z =lhs(row + 2, col); \
+ } else if (row + 1 < m_size) { \
+ reg.x =lhs(row + 0, col); \
+ reg.y =lhs(row + 1, col); \
+ } else if (row < m_size) { \
+ reg.x =lhs(row + 0, col); \
+ } \
+ } \
+ } \
+
+#define prefetch_rhs_hipcc(reg, row, col) \
+ reg.x =rhs(row + 0, col); \
+ reg.y =rhs(row + 1, col); \
+ reg.z =rhs(row + 2, col); \
+ reg.w =rhs(row + 3, col); \
+
+
+#else
+
+#define prefetch_lhs(reg, row, col) \
+ if (!CHECK_LHS_BOUNDARY) { \
+ if (col < k_size) { \
reg =lhs.template loadPacket<Unaligned>(row, col); \
} \
} else { \
@@ -563,14 +605,21 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
reg.x =lhs(row + 0, col); \
} \
} \
- } \
+ } \
+#endif
Index lhs_vert = base_m+threadIdx.x*4;
for (Index k = 0; k < k_size; k += 16) {
+
+#if defined(EIGEN_HIPCC)
+ lhs_pf0 = make_float4(0, 0, 0, 0);
+ rhs_pf0 = make_float4(0, 0, 0, 0);
+#else
lhs_pf0 = internal::pset1<float4>(0);
rhs_pf0 = internal::pset1<float4>(0);
+#endif
Index lhs_horiz = threadIdx.y+k;
prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
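
The HIP branches above sidestep two float4 packet operations, presumably because the packet path was not yet reliable under hipcc at the time: internal::pset1<float4>(0) is replaced by a plain make_float4(0, 0, 0, 0), and loadPacket<Unaligned> is unrolled into four scalar mapper reads. The scalar fallback amounts to the following (hypothetical helper; the patch keeps it as macros so the surrounding boundary checks stay untouched):

// Sketch only: fills the same float4 lanes the packet load would fill,
// assuming (row .. row+3, col) are in bounds for the mapper.
template <typename Mapper, typename Index>
__device__ inline float4 load4_scalar(const Mapper& m, Index row, Index col) {
  float4 r;
  r.x = m(row + 0, col);
  r.y = m(row + 1, col);
  r.z = m(row + 2, col);
  r.w = m(row + 3, col);
  return r;
}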
@@ -581,7 +630,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
+#if defined(EIGEN_HIPCC)
+ prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0)
+#else
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
+#endif
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@@ -596,7 +649,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
} else {
if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0)
+#else
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
+#endif
} else if ((rhs_vert + 2) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
@@ -618,7 +675,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
x1 = rhs_pf0.x;
x2 = rhs_pf0.z;
}
- #if defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000
+ #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC_VER) && EIGEN_CUDACC_VER < 90000)
x1 = __shfl_xor(x1, 4);
x2 = __shfl_xor(x2, 4);
#else
@@ -695,7 +752,11 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
#undef prefetch_lhs
#undef add_vals
-
+
+#if defined(EIGEN_HIPCC)
+#undef prefetch_rhs_hipcc
+#endif
+
Index horiz_base = threadIdx.y*4+base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 4; i++) {
@@ -784,9 +845,33 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
+#if defined(EIGEN_HIPCC)
+
+#define prefetch_lhs_hipcc(reg, row, col) \
+ reg.x =lhs(row + 0, col); \
+ reg.y =lhs(row + 1, col); \
+ reg.z =lhs(row + 2, col); \
+ reg.w =lhs(row + 3, col);
+
+#define prefetch_rhs_hipcc(reg, row, col) \
+ reg.x =rhs(row + 0, col); \
+ reg.y =rhs(row + 1, col); \
+ reg.z =rhs(row + 2, col); \
+ reg.w =rhs(row + 3, col);
+
+#endif
Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
for (Index k = 0; k < k_size; k += 32) {
+#if defined(EIGEN_HIPCC)
+ lhs_pf0 = make_float4(0, 0, 0, 0);
+ lhs_pf1 = make_float4(0, 0, 0, 0);
+ lhs_pf2 = make_float4(0, 0, 0, 0);
+ lhs_pf3 = make_float4(0, 0, 0, 0);
+
+ rhs_pf0 = make_float4(0, 0, 0, 0);
+ rhs_pf1 = make_float4(0, 0, 0, 0);
+#else
lhs_pf0 = internal::pset1<float4>(0);
lhs_pf1 = internal::pset1<float4>(0);
lhs_pf2 = internal::pset1<float4>(0);
@@ -794,40 +879,85 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
rhs_pf0 = internal::pset1<float4>(0);
rhs_pf1 = internal::pset1<float4>(0);
+#endif
if (!CHECK_LHS_BOUNDARY) {
if ((threadIdx.y/4+k+24) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+ prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16))
+ prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
+#endif
} else if ((threadIdx.y/4+k+16) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+ prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
+#endif
} else if ((threadIdx.y/4+k+8) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
+#endif
} else if ((threadIdx.y/4+k) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
+#endif
}
} else {
// just CHECK_LHS_BOUNDARY
if (lhs_vert + 3 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+ prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16))
+ prefetch_lhs_hipcc(lhs_pf3, lhs_vert, (threadIdx.y/4+k+24))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
+#endif
} else if ((threadIdx.y/4+k+16) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+ prefetch_lhs_hipcc(lhs_pf2, lhs_vert, (threadIdx.y/4+k+16))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
+#endif
} else if ((threadIdx.y/4+k+8) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+ prefetch_lhs_hipcc(lhs_pf1, lhs_vert, (threadIdx.y/4+k+8))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
+#endif
} else if ((threadIdx.y/4+k) < k_size) {
+#if defined(EIGEN_HIPCC)
+ prefetch_lhs_hipcc(lhs_pf0, lhs_vert, (threadIdx.y/4+k))
+#else
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
+#endif
}
} else if (lhs_vert + 2 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
@@ -916,8 +1046,13 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
+#if defined(EIGEN_HIPCC)
+ prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0)
+ prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1)
+#else
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
+#endif
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@@ -939,8 +1074,13 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (rhs_horiz1 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
+#if defined(EIGEN_HIPCC)
+ prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0)
+ prefetch_rhs_hipcc(rhs_pf1, rhs_vert, rhs_horiz1)
+#else
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
+#endif
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@@ -961,7 +1101,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
} else if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
+#if defined(EIGEN_HIPCC)
+ prefetch_rhs_hipcc(rhs_pf0, rhs_vert, rhs_horiz0)
+#else
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
+#endif
} else if ((rhs_vert + 2) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@@ -1069,7 +1213,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
__syncthreads();
} // end loop over k
-
+#if defined(EIGEN_HIPCC)
+#undef prefetch_lhs_hipcc
+#undef prefetch_rhs_hipcc
+#endif
+
__syncthreads();
Index horiz_base = (threadIdx.y/4)*8+base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
@@ -1134,7 +1282,11 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
template<typename Index, typename LhsMapper,
typename RhsMapper, typename OutputMapper>
__global__ void
+#if defined(EIGEN_HIPCC)
+__launch_bounds__(256, 1)
+#else
__launch_bounds__(256)
+#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) {
@@ -1177,7 +1329,11 @@ EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
template<typename Index, typename LhsMapper,
typename RhsMapper, typename OutputMapper>
__global__ void
+#if defined(EIGEN_HIPCC)
+__launch_bounds__(256, 1)
+#else
__launch_bounds__(256)
+#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) {
@@ -1323,7 +1479,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(8, 8, 8);
- LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
+ LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
}
};
@@ -1334,13 +1490,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(16, 16, 1);
- LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
+ LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
} else {
const Index m_blocks = (m + 127) / 128;
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(8, 32, 1);
- LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
+ LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
}
}
};
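
The launch sites change only in the macro name: LAUNCH_GPU_KERNEL resolves to the CUDA triple-chevron launch under nvcc and to hipLaunchKernelGGL under hipcc. A sketch of the assumed shape (the real macro lives in Eigen's GPU device headers and also checks the launch status):

#if defined(EIGEN_HIPCC)
#define LAUNCH_GPU_KERNEL(kernel, gridsize, blocksize, sharedmem, device, ...) \
  hipLaunchKernelGGL(kernel, dim3(gridsize), dim3(blocksize), (sharedmem),     \
                     (device).stream(), __VA_ARGS__)
#else
#define LAUNCH_GPU_KERNEL(kernel, gridsize, blocksize, sharedmem, device, ...) \
  (kernel)<<<(gridsize), (blocksize), (sharedmem), (device).stream()>>>(__VA_ARGS__)
#endif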
@@ -1384,12 +1540,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
OutputMapper output(buffer, m);
- setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
+#if defined(EIGEN_USE_HIP)
+ setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
+#else
+ setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
+#endif
+
LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
}
};
} // end namespace Eigen
-#endif // EIGEN_USE_GPU and EIGEN_CUDACC
-#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
+#endif // EIGEN_USE_GPU and EIGEN_GPUCC
+#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
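
Closing note on the shared-memory configuration near the end of the diff: setGpuSharedMemConfig is assumed to be a thin wrapper over the vendor calls, roughly:

#include <cassert>  // sketch only; assumes the CUDA or HIP runtime headers are in scope

#if defined(EIGEN_USE_HIP)
static inline void setGpuSharedMemConfig(hipSharedMemConfig config) {
  hipError_t status = hipDeviceSetSharedMemConfig(config);
  assert(status == hipSuccess);  // e.g. hipSharedMemBankSizeEightByte
}
#else
static inline void setGpuSharedMemConfig(cudaSharedMemConfig config) {
  cudaError_t status = cudaDeviceSetSharedMemConfig(config);
  assert(status == cudaSuccess);  // e.g. cudaSharedMemBankSizeEightByte
}
#endif

Requesting eight-byte banks reduces shared-memory bank conflicts for eight-byte (e.g. double) accesses on CUDA hardware; HIP accepts the equivalent call for API compatibility.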