diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 21 |
1 files changed, 5 insertions, 16 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 8bbe449cc..1181c2753 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -250,28 +250,17 @@ inline void TensorExecutor<Expression, GpuDevice, Vectorizable>::run( TensorEvaluator<Expression, GpuDevice> evaluator(expr, device); const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL); if (needs_assign) { -#if defined(EIGEN_HIPCC) - const int block_size = device.maxHipThreadsPerBlock(); - const int max_blocks = device.getNumHipMultiProcessors() * - device.maxHipThreadsPerMultiProcessor() / block_size; - const Index size = array_prod(evaluator.dimensions()); - // Create a least one block to ensure we won't crash when tensorflow calls with tensors of size 0. - const int num_blocks = numext::maxi<int>(numext::mini<int>(max_blocks, divup<int>(size, block_size)), 1); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>), - dim3(num_blocks), dim3(block_size), 0, device.stream(), evaluator, size); -#else - const int block_size = device.maxCudaThreadsPerBlock(); - const int max_blocks = device.getNumCudaMultiProcessors() * - device.maxCudaThreadsPerMultiProcessor() / block_size; + + const int block_size = device.maxGpuThreadsPerBlock(); + const int max_blocks = device.getNumGpuMultiProcessors() * + device.maxGpuThreadsPerMultiProcessor() / block_size; const Index size = array_prod(evaluator.dimensions()); // Create a least one block to ensure we won't crash when tensorflow calls with tensors of size 0. const int num_blocks = numext::maxi<int>(numext::mini<int>(max_blocks, divup<int>(size, block_size)), 1); - LAUNCH_CUDA_KERNEL( + LAUNCH_GPU_KERNEL( (EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>), num_blocks, block_size, 0, device, evaluator, size); -#endif } evaluator.cleanup(); } |