diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 0ffe68ab3..24a57970a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -201,7 +201,7 @@ class TensorExecutor<Expression, GpuDevice, Vectorizable> { }; -#if defined(EIGEN_CUDACC) +#if defined(EIGEN_CUDACC) || defined(EIGEN_HIPCC) template <typename Evaluator, typename Index, bool Vectorizable> struct EigenMetaKernelEval { static __device__ EIGEN_ALWAYS_INLINE @@ -250,6 +250,17 @@ inline void TensorExecutor<Expression, GpuDevice, Vectorizable>::run( TensorEvaluator<Expression, GpuDevice> evaluator(expr, device); const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL); if (needs_assign) { +#if defined(EIGEN_HIPCC) + const int block_size = device.maxHipThreadsPerBlock(); + const int max_blocks = device.getNumHipMultiProcessors() * + device.maxHipThreadsPerMultiProcessor() / block_size; + const Index size = array_prod(evaluator.dimensions()); + // Create a least one block to ensure we won't crash when tensorflow calls with tensors of size 0. + const int num_blocks = numext::maxi<int>(numext::mini<int>(max_blocks, divup<int>(size, block_size)), 1); + + hipLaunchKernelGGL(HIP_KERNEL_NAME(EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>), + dim3(num_blocks), dim3(block_size), 0, device.stream(), evaluator, size); +#else const int block_size = device.maxCudaThreadsPerBlock(); const int max_blocks = device.getNumCudaMultiProcessors() * device.maxCudaThreadsPerMultiProcessor() / block_size; @@ -260,11 +271,12 @@ inline void TensorExecutor<Expression, GpuDevice, Vectorizable>::run( LAUNCH_CUDA_KERNEL( (EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>), num_blocks, block_size, 0, device, evaluator, size); +#endif } evaluator.cleanup(); } -#endif // EIGEN_CUDACC +#endif // EIGEN_CUDACC || EIGEN_HIPCC #endif // EIGEN_USE_GPU // SYCL Executor policy |