aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/cuda_kernel_helper.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/cuda_kernel_helper.h')
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
new file mode 100644
index 0000000000..09304af13c
--- /dev/null
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -0,0 +1,52 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+
+// Loop header for use inside a kernel: declares an `int` index `i` that starts
+// at this thread's global x-index (blockIdx.x * blockDim.x + threadIdx.x) and
+// advances by the total number of threads along x in the grid
+// (blockDim.x * gridDim.x) while `i < n`. Usage:
+//
+//   CUDA_1D_KERNEL_LOOP(i, n) { out[i] = in[i]; }
+//
+// `n` is parenthesized so that an expression argument (for example a
+// conditional expression) is compared against `i` as a whole. Note that `n`
+// is re-evaluated on every iteration.
+#define CUDA_1D_KERNEL_LOOP(i, n)                               \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
+       i += blockDim.x * gridDim.x)
+
+namespace tensorflow {
+
+// Shorthand for Eigen's GPU device type within the tensorflow namespace.
+using GPUDevice = Eigen::GpuDevice;
+
+// Bundle of values describing a Cuda kernel launch. Every field defaults to
+// -1 until it is explicitly assigned.
+struct CudaLaunchConfig {
+  // Count of logical threads operating on the elements. When one logical
+  // thread handles exactly one element, this equals the number of elements
+  // being processed.
+  int virtual_thread_count{-1};
+  // Threads contained in each block.
+  int thread_per_block{-1};
+  // Blocks to use for the Cuda kernel launch.
+  int block_count{-1};
+};
+
+// Calculate the Cuda launch config we should use for a kernel launch.
+// This is assuming the kernel is quite simple and will largely be
+// memory-limited.
+//
+// Args:
+//   work_element_count: number of elements to process; stored unchanged in
+//     the returned `virtual_thread_count`.
+//   d: device that is queried for its multiprocessor count and its
+//     per-multiprocessor / per-block thread limits.
+//
+// Returns a config where
+//   * thread_per_block = min(1024, d.maxCudaThreadsPerBlock());
+//   * block_count = number of such blocks needed to cover
+//     min(work_element_count, multiprocessors * max threads per
+//     multiprocessor) threads, capped at the number of multiprocessors and
+//     never less than 1.
+//
+// block_count is clamped to at least 1 because a grid with zero blocks is not
+// a valid kernel launch configuration. With a zero (or negative)
+// work_element_count, a kernel that bounds its work by virtual_thread_count
+// (e.g. via CUDA_1D_KERNEL_LOOP) simply performs no iterations.
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+                                            const GPUDevice& d) {
+  const int virtual_thread_count = work_element_count;
+  const int multiprocessor_count = d.getNumCudaMultiProcessors();
+  // Never plan for more threads than the device can keep resident at once.
+  const int physical_thread_count =
+      std::min(multiprocessor_count * d.maxCudaThreadsPerMultiProcessor(),
+               virtual_thread_count);
+  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+  // Ceiling division: blocks required to hold physical_thread_count threads.
+  const int blocks_needed =
+      (physical_thread_count + thread_per_block - 1) / thread_per_block;
+  // At most one block per multiprocessor, but always at least one block so the
+  // resulting grid dimension is launchable.
+  const int block_count =
+      std::max(1, std::min(blocks_needed, multiprocessor_count));
+
+  CudaLaunchConfig config;
+  config.virtual_thread_count = virtual_thread_count;
+  config.thread_per_block = thread_per_block;
+  config.block_count = block_count;
+  return config;
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_