aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h16
1 files changed, 14 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index 0ffe68ab3..24a57970a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -201,7 +201,7 @@ class TensorExecutor<Expression, GpuDevice, Vectorizable> {
};
-#if defined(EIGEN_CUDACC)
+#if defined(EIGEN_CUDACC) || defined(EIGEN_HIPCC)
template <typename Evaluator, typename Index, bool Vectorizable>
struct EigenMetaKernelEval {
static __device__ EIGEN_ALWAYS_INLINE
@@ -250,6 +250,17 @@ inline void TensorExecutor<Expression, GpuDevice, Vectorizable>::run(
TensorEvaluator<Expression, GpuDevice> evaluator(expr, device);
const bool needs_assign = evaluator.evalSubExprsIfNeeded(NULL);
if (needs_assign) {
+#if defined(EIGEN_HIPCC)
+ const int block_size = device.maxHipThreadsPerBlock();
+ const int max_blocks = device.getNumHipMultiProcessors() *
+ device.maxHipThreadsPerMultiProcessor() / block_size;
+ const Index size = array_prod(evaluator.dimensions());
+ // Create a least one block to ensure we won't crash when tensorflow calls with tensors of size 0.
+ const int num_blocks = numext::maxi<int>(numext::mini<int>(max_blocks, divup<int>(size, block_size)), 1);
+
+ hipLaunchKernelGGL(HIP_KERNEL_NAME(EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>),
+ dim3(num_blocks), dim3(block_size), 0, device.stream(), evaluator, size);
+#else
const int block_size = device.maxCudaThreadsPerBlock();
const int max_blocks = device.getNumCudaMultiProcessors() *
device.maxCudaThreadsPerMultiProcessor() / block_size;
@@ -260,11 +271,12 @@ inline void TensorExecutor<Expression, GpuDevice, Vectorizable>::run(
LAUNCH_CUDA_KERNEL(
(EigenMetaKernel<TensorEvaluator<Expression, GpuDevice>, Index>),
num_blocks, block_size, 0, device, evaluator, size);
+#endif
}
evaluator.cleanup();
}
-#endif // EIGEN_CUDACC
+#endif // EIGEN_CUDACC || EIGEN_HIPCC
#endif // EIGEN_USE_GPU
// SYCL Executor policy