diff options
Diffstat (limited to 'tensorflow/core/util/cuda_kernel_helper.h')
-rw-r--r-- | tensorflow/core/util/cuda_kernel_helper.h | 52 |
1 files changed, 52 insertions, 0 deletions
// Source: tensorflow/core/util/cuda_kernel_helper.h
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

#if GOOGLE_CUDA

#include <algorithm>

// Grid-stride loop over a 1-D index space, for use inside a CUDA kernel.
//
// Declares `int i`, starting at this thread's global 1-D index
// (blockIdx.x * blockDim.x + threadIdx.x) and advancing by the total number
// of threads in the grid (blockDim.x * gridDim.x) while `i < (n)`.
//
// `n` is parenthesized so that an argument containing a lower-precedence
// operator (e.g. `cond ? a : b`) is compared as a whole. Note that `n` is
// re-evaluated on every iteration, so it should be free of side effects.
#define CUDA_1D_KERNEL_LOOP(i, n)                                \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)

namespace tensorflow {

// NOTE(review): Eigen::GpuDevice is not declared by anything this header
// includes; presumably the including file must pull in the Eigen Tensor
// header first -- confirm, or add the include here.
typedef Eigen::GpuDevice GPUDevice;

// Launch parameters for a simple 1-D CUDA kernel. Every field defaults to -1
// until it is filled in (see GetCudaLaunchConfig).
struct CudaLaunchConfig {
  // Logical number of threads that work on the elements. If each logical
  // thread works on exactly a single element, this is the same as the working
  // element count.
  int virtual_thread_count = -1;
  // Number of threads per block.
  int thread_per_block = -1;
  // Number of blocks for Cuda kernel launch.
  int block_count = -1;
};

// Calculates the Cuda launch config we should use for a kernel launch.
// This assumes the kernel is quite simple and will largely be
// memory-limited.
//
// Args:
//   work_element_count: number of elements the kernel has to process.
//   d: device whose multiprocessor count and thread limits are queried.
//
// Returns a config where
//   - virtual_thread_count equals work_element_count,
//   - thread_per_block is the device's max threads per block, capped at 1024,
//   - block_count is enough blocks to cover min(work_element_count, number of
//     multiprocessors * max threads per multiprocessor) threads, capped at the
//     number of multiprocessors and never less than 1.
//
// Because block_count is capped, a kernel launched with this config must
// iterate over its elements (e.g. with CUDA_1D_KERNEL_LOOP) rather than
// assume one thread per element.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const GPUDevice& d) {
  const int virtual_thread_count = work_element_count;
  const int physical_thread_count = std::min(
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
      virtual_thread_count);
  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
  const int covering_block_count =
      (physical_thread_count + thread_per_block - 1) / thread_per_block;
  // A grid dimension of zero (or less) is not a valid launch configuration,
  // which is what work_element_count <= 0 would otherwise produce. Keep at
  // least one block; CUDA_1D_KERNEL_LOOP then simply runs zero iterations.
  const int block_count = std::max(
      1, std::min(covering_block_count, d.getNumCudaMultiProcessors()));

  CudaLaunchConfig config;
  config.virtual_thread_count = virtual_thread_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_