aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-04 05:09:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 05:13:51 -0800
commit540e86701e077b6b537ee839be296dc3a6cd167a (patch)
tree8a458fc71cd890220d0b31ce3b071294a061089b
parentc02cfb040d2609d605b909b81f4419e948e1560d (diff)
Wrappers for CUDA 9 warp-synchronous intrinsics.
PiperOrigin-RevId: 177799252
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc11
-rw-r--r--tensorflow/core/BUILD7
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc18
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc11
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc21
-rw-r--r--tensorflow/core/kernels/svd_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/util/cuda_device_functions.h418
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h837
-rw-r--r--tensorflow/core/util/cuda_kernel_helper_test.cu.cc12
-rw-r--r--tensorflow/core/util/cuda_launch_config.h284
10 files changed, 851 insertions, 772 deletions
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
index 8e6870fadd..501cddb8c8 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -34,9 +34,9 @@ namespace functor {
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
@@ -68,8 +68,9 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
- Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
- ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
+ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
+ 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4b5f67baad..d77021c3ee 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1847,6 +1847,13 @@ cc_library(
],
)
+tf_cuda_library(
+ name = "cuda_device_functions",
+ hdrs = ["util/cuda_device_functions.h"],
+ visibility = ["//visibility:public"],
+ deps = [":framework_lite"],
+)
+
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 42f3db1d79..f9a207208a 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,19 +173,13 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
- if (thread_index < 16) {
- s_data[thread_index] += s_data[thread_index + 16];
- __syncwarp(0xFFFF);
- if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
- __syncwarp(0xFF);
- if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
- __syncwarp(0xF);
- if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
- __syncwarp(0x3);
+ if (thread_index < 32) {
+ AccT data = s_data[thread_index];
+ for (int32 offset = warpSize / 2; offset > 0; offset /= 2) {
+ data += CudaShuffleDownSync(kCudaWarpAll, data, offset);
+ }
if (thread_index == 0) {
- T val = T(s_data[0] + s_data[1]);
- // The first thread writes out the accumulated result to global location.
- CudaAtomicAdd(bias_backprop + bias_index, val);
+ CudaAtomicAdd(bias_backprop + bias_index, T(data));
}
}
}
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 903aac5d68..de0bf84c8b 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -34,6 +34,7 @@ limitations under the License.
namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;
// Returns whether depthwise convolution forward or backward input pass can be
@@ -1028,7 +1029,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
int zeros = sub_warp * kWidth;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
- val += CudaShuffleXor(mask, val, delta);
+ val += CudaShuffleXorSync(mask, val, delta);
}
return val;
}
@@ -1145,7 +1146,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1159,7 +1160,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleDownSync(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1399,7 +1400,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1413,7 +1414,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleDownSync(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;
diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
index 31f74671ca..a3c21edc15 100644
--- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -55,6 +55,27 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};
+// Specializations for std::complex, updating real and imaginary part
+// individually. Even though this is not an atomic op anymore, it is safe
+// because there is only one type of op per kernel.
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ T* ptr = reinterpret_cast<T*>(out);
+ CudaAtomicAdd(ptr, val.real());
+ CudaAtomicAdd(ptr, val.imag());
+ }
+};
+
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
+ }
+};
+
} // namespace
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index dedc2da60b..8c3a58b108 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
- CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
+ CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
}
diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h
new file mode 100644
index 0000000000..973a43d78f
--- /dev/null
+++ b/tensorflow/core/util/cuda_device_functions.h
@@ -0,0 +1,418 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+
+/**
+ * Wrappers and helpers for CUDA device code.
+ *
+ * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
+ * backwards compatibility, see go/volta-porting for details.
+ * Provides atomic operations on types that aren't natively supported.
+ */
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+#include <complex>
+#include "cuda/include/cuda.h"
+#include "cuda/include/device_functions.h"
+#include "tensorflow/core/platform/types.h"
+
+#if __CUDACC_VER_MAJOR__ >= 9
+#include "cuda/include/cuda_fp16.h"
+#elif __CUDACC_VER__ >= 7050
+#include "cuda/include/cuda_fp16.h"
+#else
+#endif
+
+namespace tensorflow {
+
+namespace detail {
+
+// Helper for range-based for loop using 'delta' increments.
+// Usage: see CudaGridRange?() functions below.
+template <typename T>
+class CudaGridRange {
+ struct Iterator {
+ __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
+ __device__ T operator*() const { return index_; }
+ __device__ Iterator& operator++() {
+ index_ += delta_;
+ return *this;
+ }
+ __device__ bool operator!=(const Iterator& other) const {
+ bool greater = index_ > other.index_;
+ bool less = index_ < other.index_;
+ // Anything past an end iterator (delta_ == 0) is equal.
+ // In range-based for loops, this optimizes to 'return less'.
+ if (!other.delta_) {
+ return less;
+ }
+ if (!delta_) {
+ return greater;
+ }
+ return less || greater;
+ }
+
+ private:
+ T index_;
+ const T delta_;
+ };
+
+ public:
+ __device__ CudaGridRange(T begin, T delta, T end)
+ : begin_(begin), delta_(delta), end_(end) {}
+
+ __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
+ __device__ Iterator end() const { return Iterator{end_, 0}; }
+
+ private:
+ T begin_;
+ T delta_;
+ T end_;
+};
+
+} // namespace detail
+
+// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
+// of the global thread index. That is, each index i is visited by all threads
+// with the same x-coordinate.
+// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
+ return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
+ gridDim.x * blockDim.x, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
+// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
+ return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
+ gridDim.y * blockDim.y, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
+// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
+ return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
+ gridDim.z * blockDim.z, count);
+}
+
+// Mask for all 32 threads in a warp.
+const unsigned kCudaWarpAll = 0xffffffff;
+
+// On sm_6x and earlier, verifies that all bits in mask corresponding to active
+// threads of the warp are set. It does not verify the converse (bits of
+// inactive threads are not set), because all syncs are unblocked when a thread
+// exits the kernel, but the ballot of inactive (including exited) threads
+// returns 0.
+__device__ inline void CudaVerifySyncMask(unsigned mask) {
+#if __CUDA_ARCH__ < 700
+ assert(0 == (__ballot(1) & ~mask)); // Active threads must have mask bit set.
+#endif
+}
+
+// For all *_sync wrappers below, it is illegal to synchronize threads from
+// different program locations, because that is not supported before sm_70.
+// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
+
+// Wrapper for __syncwarp.
+__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ __syncwarp(mask);
+#endif
+}
+
+// Wrapper for __ballot_sync.
+__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __ballot_sync(mask, pred);
+#else
+ return __ballot(pred);
+#endif
+}
+
+// Wrapper for __any_sync.
+__device__ inline int CudaAnySync(unsigned mask, int pred) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __any_sync(mask, pred);
+#else
+ return __any(pred);
+#endif
+}
+
+// Wrapper for __all_sync.
+__device__ inline int CudaAllSync(unsigned mask, int pred) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __all_sync(mask, pred);
+#else
+ return __all(pred);
+#endif
+}
+
+// Wrapper for __shfl_sync.
+template <typename T>
+__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
+ int width = warpSize) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __shfl_sync(mask, value, src_lane, width);
+#else
+ return __shfl(value, src_lane, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleSync(unsigned mask, double value,
+ int src_lane, int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleSync(mask, hi, src_lane, width);
+ lo = CudaShuffleSync(mask, lo, src_lane, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_up_sync.
+template <typename T>
+__device__ inline T CudaShuffleUpSync(unsigned mask, T value, int delta,
+ int width = warpSize) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __shfl_up_sync(mask, value, delta, width);
+#else
+ return __shfl_up(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
+ int delta, int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleUpSync(mask, hi, delta, width);
+ lo = CudaShuffleUpSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_down_sync.
+template <typename T>
+__device__ inline T CudaShuffleDownSync(unsigned mask, T value, int delta,
+ int width = warpSize) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __shfl_down_sync(mask, value, delta, width);
+#else
+ return __shfl_down(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
+ int delta, int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleDownSync(mask, hi, delta, width);
+ lo = CudaShuffleDownSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_xor_sync.
+template <typename T>
+__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
+ int width = warpSize) {
+ CudaVerifySyncMask(mask);
+#if CUDA_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, lane_mask, width);
+#else
+ return __shfl_xor(value, lane_mask, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
+ int lane_mask,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
+ lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __ldg.
+template <typename T>
+__host__ __device__ T CudaLdg(const T* address) {
+#if __CUDA_ARCH__ >= 350
+ return __ldg(address);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline bool CudaLdg(const bool* address) {
+ return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
+}
+
+__host__ __device__ inline std::complex<float> CudaLdg(
+ const std::complex<float>* address) {
+#if __CUDA_ARCH__ >= 350
+ float2 mem = __ldg(reinterpret_cast<const float2*>(address));
+ return std::complex<float>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline std::complex<double> CudaLdg(
+ const std::complex<double>* address) {
+#if __CUDA_ARCH__ >= 350
+ double2 mem = __ldg(reinterpret_cast<const double2*>(address));
+ return std::complex<double>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+// Zeroes count elements starting at ptr using all threads of a 1-D grid.
+// Note: this function does not synchronize, and therefore the memory range is
+// not guaranteed to be zero until the next kernel launch.
+template <typename T>
+__global__ void SetZero(const int count, T* ptr) {
+ // Check that the grid is one dimensional and index doesn't overflow.
+ assert(blockDim.y == 1 && blockDim.z == 1);
+ assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
+ for (int i : CudaGridRangeX(count)) {
+ ptr[i] = T(0);
+ }
+}
+
+namespace detail {
+// Helper function for atomic accumulation implemented as CAS.
+template <typename T, typename F>
+__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
+ T old = *ptr;
+ T assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(ptr, assumed, accumulate(assumed));
+ } while (assumed != old);
+ return old;
+}
+
+// Overload for floating point (using integer comparison to handle NaN
+// correctly).
+template <typename F>
+__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
+ return __float_as_int(
+ CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
+ return __float_as_int(accumulate(__int_as_float(a)));
+ }));
+}
+template <typename F>
+__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
+ return __longlong_as_double(CudaAtomicCasHelper(
+ reinterpret_cast<tensorflow::uint64*>(ptr),
+ [accumulate](tensorflow::uint64 a) {
+ return __double_as_longlong(accumulate(__longlong_as_double(a)));
+ }));
+}
+} // namespace detail
+
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementation for all reasonable types.
+
+template <typename T>
+__device__ T CudaAtomicAdd(T* ptr, T value) {
+ return atomicAdd(ptr, value);
+}
+#if __CUDA_ARCH__ < 600
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ return detail::CudaAtomicCasHelper(ptr,
+ [value](double a) { return a + value; });
+}
+#elif __clang__
+// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
+// see https://reviews.llvm.org/D39638
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ double result;
+ asm volatile("atom.add.f64 %0, [%1], %2;"
+ : "=d"(result)
+ : "l"(ptr), "d"(value)
+ : "memory");
+ return result;
+}
+#endif
+
+template <typename T>
+__device__ T CudaAtomicSub(T* ptr, T value) {
+ return atomicSub(ptr, value);
+}
+// Specializations of substraction which add the negative value.
+__device__ inline float CudaAtomicSub(float* ptr, float value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline double CudaAtomicSub(double* ptr, double value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+
+template <typename T>
+__device__ T CudaAtomicMax(T* ptr, T value) {
+ return atomicMax(ptr, value);
+}
+#if __CUDA_ARCH__ < 320
+__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](tensorflow::uint64 a) { return max(a, value); });
+}
+#endif
+
+template <typename T>
+__device__ inline T CudaAtomicMul(T* ptr, T value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
+}
+template <typename T>
+__device__ inline T CudaAtomicDiv(T* ptr, T value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index cf11f419a4..b71218d73c 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -18,299 +18,125 @@ limitations under the License.
#if GOOGLE_CUDA
-#include <algorithm>
+#include "tensorflow/core/util/cuda_device_functions.h"
+#include "tensorflow/core/util/cuda_launch_config.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "cuda/include/cuda.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
-#include "tensorflow/core/platform/types.h"
+// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i : ::tensorflow::CudaGridRangeX<int>(n))
+// Deprecated, use 'for(int i : CudaGridRange?(n))' instead.
+#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
+ for (int i : ::tensorflow::CudaGridRange##axis<int>(n))
-// Mask for all 32 threads in a warp.
-#define CUDA_WARP_ALL 0xFFFFFFFF
-
-#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
-// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
-// that operates at the warp-scope. This is required to ensure visibility of
-// reads/writes among threads that can make indepenent progress on Volta.
-// For previous CUDA versions these synchronizations not necessary, and we
-// define an empty function as a convenience for backward compatibility.
-__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
-
-// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
-// favor of synchronizing versions. These ensure that all warp lanes specified
-// in mask execute the intrinsic in convergence. Here we provide legacy mappings
-// to the less-verbose routines provided in previous versions of CUDA.
-#define __ballot_sync(mask, predicate) __ballot(predicate)
-#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
-#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
-#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
-#define __shfl_xor_sync(mask, val, laneMask, width) \
- __shfl_xor(val, laneMask, width)
-#endif
-
-// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
-// GetCuda3DLaunchConfig:
-//
-// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
-// version uses heuristics without any knowledge of the device kernel, the other
-// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
-// launch parameters that maximize occupancy. Currently, only the maximum
-// occupancy version of GetCuda3DLaunchConfig is available.
-//
-// For large number of work elements, the convention is that each kernel would
-// iterate through its assigned range. The return value of GetCudaLaunchConfig
-// is struct CudaLaunchConfig, which contains all the information needed for the
-// kernel launch, including: virtual number of threads, the number of threads
-// per block and number of threads per block used inside <<< >>> of a kernel
-// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
-// as CudaLaunchConfig. The only difference is the dimension. The macros
-// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
-//
-/* Sample code:
-
-__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
- CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
- do_your_job_here;
- }
+namespace tensorflow {
+template <typename T>
+__host__ __device__ inline T ldg(const T* ptr) {
+ return CudaLdg(ptr);
}
-__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- do_your_job_here;
- }
- }
+template <typename T>
+__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
+ return x < y ? x : y;
}
-__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
- do_your_job_here;
- }
- }
- }
+template <typename T>
+__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
+ return x < y ? y : x;
}
-void MyDriverFunc(const GPUDevice &d) {
- // use heuristics
- CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
- Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
- Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
-
- // maximize occupancy
- CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
- Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
- MyKernel1D, 0, 0);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
- Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
- MyKernel1D, 0, 0);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+// Overloads of the above functions for float and double.
+__host__ __device__ inline float tf_min(float x, float y) {
+ return fminf(x, y);
}
-
-// See the test for this for more example:
-//
-https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
-
-*/
-
-#define CUDA_1D_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
- i += blockDim.x * gridDim.x)
-
-#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
- for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
- i += blockDim.axis * gridDim.axis)
-
-#define DIV_UP(a, b) (((a) + (b)-1) / (b))
-
-namespace tensorflow {
-
-typedef Eigen::GpuDevice GPUDevice;
-
-struct CudaLaunchConfig {
- // Logical number of thread that works on the elements. If each logical
- // thread works on exactly a single element, this is the same as the working
- // element count.
- int virtual_thread_count = -1;
- // Number of threads per block.
- int thread_per_block = -1;
- // Number of blocks for Cuda kernel launch.
- int block_count = -1;
-};
-
-// Calculate the Cuda launch config we should use for a kernel launch.
-// This is assuming the kernel is quite simple and will largely be
-// memory-limited.
-// REQUIRES: work_element_count > 0.
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- const int virtual_thread_count = work_element_count;
- const int physical_thread_count = std::min(
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
- virtual_thread_count);
- const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
- const int block_count =
- std::min(DIV_UP(physical_thread_count, thread_per_block),
- d.getNumCudaMultiProcessors());
-
- config.virtual_thread_count = virtual_thread_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+__host__ __device__ inline double tf_min(double x, double y) {
+ return fmin(x, y);
}
-
-// Calculate the Cuda launch config we should use for a kernel launch. This
-// variant takes the resource limits of func into account to maximize occupancy.
-// REQUIRES: work_element_count > 0.
-template <typename DeviceFunc>
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size,
- int block_size_limit) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- int block_count = 0;
- int thread_per_block = 0;
-
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
- block_count =
- std::min(block_count, DIV_UP(work_element_count, thread_per_block));
-
- config.virtual_thread_count = work_element_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+__host__ __device__ inline float tf_max(float x, float y) {
+ return fmaxf(x, y);
+}
+__host__ __device__ inline double tf_max(double x, double y) {
+ return fmax(x, y);
}
-struct Cuda2DLaunchConfig {
- dim3 virtual_thread_count = dim3(0, 0, 0);
- dim3 thread_per_block = dim3(0, 0, 0);
- dim3 block_count = dim3(0, 0, 0);
-};
-
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
- const GPUDevice& d) {
- Cuda2DLaunchConfig config;
-
- if (xdim <= 0 || ydim <= 0) {
- return config;
- }
-
- const int kThreadsPerBlock = 256;
- int block_cols = std::min(xdim, kThreadsPerBlock);
- // ok to round down here and just do more loops in the kernel
- int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
-
- const int physical_thread_count =
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
-
- const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
-
- config.virtual_thread_count = dim3(xdim, ydim, 1);
- config.thread_per_block = dim3(block_cols, block_rows, 1);
-
- int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);
+__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value,
+ int src_lane,
+ int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
+}
- config.block_count = dim3(
- grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
- return config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
-// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
-// This variant takes the resource limits of func into account to maximize
-// occupancy.
-using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
+}
-template <typename DeviceFunc>
-inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
- int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- Cuda3DLaunchConfig config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
+ unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
+}
- if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
- return config;
+namespace detail {
+// Overload of above function for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// This version is going to be very slow
+// under high concurrency, since most threads will be spinning on failing
+// their compare-and-swap tests. (The fact that we get false sharing on the
+// neighboring fp16 makes this even worse.) If you are doing a large reduction,
+// you are much better off with doing the intermediate steps in fp32 and then
+// switching to fp16 as late as you can in the calculations.
+//
+// Note: Assumes little endian.
+template <typename F>
+__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
+ namespace half_impl = Eigen::half_impl;
+ intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
+ if (intptr & 0x3) {
+ assert(!(intptr & 0x1));
+ // The half is in the second part of the uint32 (upper 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr - 2);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
+ Eigen::half acc = accumulate(
+ half_impl::__half_raw{static_cast<unsigned short>(a >> 16)});
+ uint32_t upper = static_cast<half_impl::__half_raw>(acc).x;
+ return (upper << 16) | (a & 0xffff);
+ });
+ return half_impl::__half_raw{static_cast<uint16>(result >> 16)};
+ } else {
+ // The half is in the first part of the uint32 (lower 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 a) {
+ Eigen::half acc = accumulate(
+ half_impl::__half_raw{static_cast<unsigned short>(a & 0xffff)});
+ uint32_t lower = static_cast<half_impl::__half_raw>(acc).x;
+ return (a & 0xffff0000) | lower;
+ });
+ return half_impl::__half_raw{static_cast<uint16>(result & 0xffff)};
}
-
- int dev;
- cudaGetDevice(&dev);
- cudaDeviceProp deviceProp;
- cudaGetDeviceProperties(&deviceProp, dev);
- int xthreadlimit = deviceProp.maxThreadsDim[0];
- int ythreadlimit = deviceProp.maxThreadsDim[1];
- int zthreadlimit = deviceProp.maxThreadsDim[2];
- int xgridlimit = deviceProp.maxGridSize[0];
- int ygridlimit = deviceProp.maxGridSize[1];
- int zgridlimit = deviceProp.maxGridSize[2];
-
- int block_count = 0;
- int thread_per_block = 0;
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
-#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
- int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
- int threadsy =
- MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
- int threadsz =
- MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
- zthreadlimit);
-
- int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
- int blocksy =
- MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
- int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
- DIV_UP(zdim, threadsz), zgridlimit);
-#undef MIN3
-
- config.virtual_thread_count = dim3(xdim, ydim, zdim);
- config.thread_per_block = dim3(threadsx, threadsy, threadsz);
- config.block_count = dim3(blocksx, blocksy, blocksz);
- return config;
}
+} // namespace detail
-template <typename DeviceFunc>
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
- int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
- dynamic_shared_memory_size, block_size_limit);
+__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a + value; });
}
-
-// Returns a raw reference to the current cuda stream. Required by a
-// number of kernel calls (for which StreamInterface* does not work), i.e.
-// CUB and certain cublas primitives.
-inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
- const cudaStream_t* ptr = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(context->op_device_context()
- ->stream()
- ->implementation()
- ->CudaStreamMemberHack()));
- return *ptr;
+__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a - value; });
}
namespace cuda_helper {
-
template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
IntType* orig = first;
@@ -330,481 +156,8 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
-
} // namespace cuda_helper
-
-template <typename T>
-__device__ __host__ inline T ldg(const T* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return __ldg(address);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<float> ldg(
- const std::complex<float>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- float2 mem = __ldg(reinterpret_cast<const float2*>(address));
- return std::complex<float>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<double> ldg(
- const std::complex<double>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- double2 mem = __ldg(reinterpret_cast<const double2*>(address));
- return std::complex<double>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return Eigen::half_impl::raw_uint16_to_half(
- __ldg(reinterpret_cast<const uint16_t*>(address)));
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline bool ldg(const bool* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return *reinterpret_cast<const bool*>(
- __ldg(reinterpret_cast<const char*>(address)));
-#else
- return *address;
-#endif
-}
-
-// CUDA provides atomic ops, but not for all types. We provide wrappers
-// for some ops and provide implementation for all reasonable types.
-#define CUDA_ATOMIC_WRAPPER(op, T) \
- __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
-
-#define USE_CUDA_ATOMIC(op, T) \
- CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
-
-// For atomicAdd.
-USE_CUDA_ATOMIC(Add, int32);
-USE_CUDA_ATOMIC(Add, uint32);
-USE_CUDA_ATOMIC(Add, uint64);
-USE_CUDA_ATOMIC(Add, float);
-
-// For atomicMax.
-USE_CUDA_ATOMIC(Max, int32);
-USE_CUDA_ATOMIC(Max, uint32);
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
-USE_CUDA_ATOMIC(Max, uint64);
-#else
-// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
-// 350. If not satisfied, we provide a custom implementation using atomicCAS().
-CUDA_ATOMIC_WRAPPER(Max, uint64) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed, max(val, assumed));
- } while (assumed != old);
-
- return old;
-}
-#endif
-
-// Custom implementation of atomicAdd for double.
-// This implementation is copied from CUDA manual.
-CUDA_ATOMIC_WRAPPER(Add, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val + __longlong_as_double(assumed)));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- return __longlong_as_double(old);
-}
-
-// Custom implementation of atomicAdd for std::complex<float>.
-// This implementation performs to atomic additions on the components.
-CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- float2* addr_as_float2 = reinterpret_cast<float2*>(address);
- float2* val_as_float2 = reinterpret_cast<float2*>(&val);
- CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x);
- CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y);
-#else
- static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
- "Unable to compile CudaAtomicAdd for complex64 because "
- "sizeof(complex64) != 2*sizeof(float32)");
- float* addr_as_float = reinterpret_cast<float*>(address);
- float* val_as_float = reinterpret_cast<float*>(&val);
- CudaAtomicAdd(addr_as_float, *val_as_float);
- CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1));
-#endif
-#endif
- return *address;
-}
-
-// Custom implementation of atomicAdd for std::complex<double>.
-// This implementation performs to atomic additions on the components
-// using the double atomic wrapper above.
-CUDA_ATOMIC_WRAPPER(Add, complex128) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- double2* addr_as_double2 = reinterpret_cast<double2*>(address);
- double2* val_as_double2 = reinterpret_cast<double2*>(&val);
- CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x);
- CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y);
-#else
- static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
- "Unable to compile CudaAtomicAdd for complex128 because "
- "sizeof(complex128) != 2*sizeof(float64)");
- double* addr_as_double = reinterpret_cast<double*>(address);
- double* val_as_double = reinterpret_cast<double*>(&val);
- CudaAtomicAdd(addr_as_double, *val_as_double);
- CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1));
-#endif
-#endif
- return *address;
-}
-
-// Helper functions for CudaAtomicAdd(half*, half), below.
-//
-// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
-// for a more efficient implementation, assuming that adding -0.0
-// will never harm the neighboring value. In this version, we take special
-// care to guarantee the bits of the untouched value are unchanged.
-inline __device__ uint32 add_to_low_half(uint32 val, float x) {
- Eigen::half low_half;
- low_half.x = static_cast<uint16>(val & 0xffffu);
- low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
- return (val & 0xffff0000u) | low_half.x;
-}
-
-inline __device__ uint32 add_to_high_half(uint32 val, float x) {
- Eigen::half high_half;
- high_half.x = static_cast<uint16>(val >> 16);
- high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
- return (val & 0xffffu) | (high_half.x << 16);
-}
-
-// Custom implementation of atomicAdd for half. Note that we don't have
-// atomicCAS() for anything less than 32 bits, so we need to include the
-// other 16 bits in the operation.
-//
-// Unlike the other atomic adds, this version is going to be very slow
-// under high concurrency, since most threads will be spinning on failing
-// their compare-and-swap tests. (The fact that we get false sharing on the
-// neighboring fp16 makes this even worse.) If you are doing a large reduction,
-// you are much better off with doing the intermediate steps in fp32 and then
-// switching to fp16 as late as you can in the calculations.
-//
-// Note: Assumes little endian.
-CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
- float val_as_float(val);
- intptr_t address_int = reinterpret_cast<intptr_t>(address);
- if ((address_int & 0x2) == 0) {
- // The half is in the first part of the uint32 (lower 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_low_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old & 0xffffu;
- return ret;
- } else {
- // The half is in the second part of the uint32 (upper 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_high_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old >> 16;
- return ret;
- }
-}
-
-template <typename T>
-__global__ void SetZero(const int nthreads, T* bottom_diff) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
-}
-
-// For atomicSub.
-
-// Custom implementation for sub by just negating the value.
-#define WRAPPED_ATOMIC_SUB(T) \
- CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }
-
-WRAPPED_ATOMIC_SUB(uint64);
-WRAPPED_ATOMIC_SUB(int32);
-WRAPPED_ATOMIC_SUB(uint32);
-WRAPPED_ATOMIC_SUB(Eigen::half);
-WRAPPED_ATOMIC_SUB(float);
-WRAPPED_ATOMIC_SUB(double);
-
-CUDA_ATOMIC_WRAPPER(Sub, complex64) {
- const std::complex<float> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-CUDA_ATOMIC_WRAPPER(Sub, complex128) {
- const std::complex<double> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-#undef WRAPPED_ATOMIC_SUB
-
-// For atomicMul.
-CUDA_ATOMIC_WRAPPER(Mul, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(val * __int_as_float(assumed)));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val * __longlong_as_double(assumed)));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-// For atomicDiv.
-CUDA_ATOMIC_WRAPPER(Div, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(__int_as_float(assumed) / val));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Div, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(__longlong_as_double(assumed) / val));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-#undef USE_CUDA_ATOMIC
-#undef CUDA_ATOMIC_WRAPPER
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
- return x > y ? y : x;
-}
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
- return x < y ? y : x;
-}
-
-__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
- int predicate) {
- return __ballot_sync(mask, predicate);
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
- int srcLane,
- int width = warpSize) {
- return __shfl_sync(mask, value, srcLane, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value,
- int srcLane,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_sync(mask, hi, srcLane, width);
- lo = __shfl_sync(mask, lo, srcLane, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_up_sync(mask, value, delta, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value,
- int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_up_sync(mask, hi, delta, width);
- lo = __shfl_up_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_down_sync(mask, value, delta, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
- unsigned mask, Eigen::half value, int delta, int width = warpSize) {
- return Eigen::half(
- __shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
- double value, int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_down_sync(mask, hi, delta, width);
- lo = __shfl_down_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
- int laneMask,
- int width = warpSize) {
- return __shfl_xor_sync(mask, value, laneMask, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
- unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
- return Eigen::half(
- __shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
- double value, int laneMask,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_xor_sync(mask, hi, laneMask, width);
- lo = __shfl_xor_sync(mask, lo, laneMask, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
} // namespace tensorflow
-#undef DIV_UP
-
#endif // GOOGLE_CUDA
-
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
index 6991554eff..4eb1558e58 100644
--- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
@@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
@@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) {
if (z < 0) { // z might overflow when testing extreme case
break;
}
@@ -94,7 +94,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
const int bufsize = 1024;
int* outbuf = nullptr;
Eigen::CudaStreamDevice stream;
- GPUDevice d = GPUDevice(&stream);
+ Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
virtual void SetUp() {
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h
new file mode 100644
index 0000000000..3ea33ee6cf
--- /dev/null
+++ b/tensorflow/core/util/cuda_launch_config.h
@@ -0,0 +1,284 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "cuda/include/cuda.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+
+// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
+// GetCuda3DLaunchConfig:
+//
+// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
+// version uses heuristics without any knowledge of the device kernel, the other
+// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
+// launch parameters that maximize occupancy. Currently, only the maximum
+// occupancy version of GetCuda3DLaunchConfig is available.
+//
+// For large number of work elements, the convention is that each kernel would
+// iterate through its assigned range. The return value of GetCudaLaunchConfig
+// is struct CudaLaunchConfig, which contains all the information needed for the
+// kernel launch, including: virtual number of threads, the number of threads
+// per block and number of threads per block used inside <<< >>> of a kernel
+// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
+// as CudaLaunchConfig. The only difference is the dimension. The macros
+// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
+//
+/* Sample code:
+
+__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
+ CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
+ do_your_job_here;
+ }
+}
+
+__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ do_your_job_here;
+ }
+ }
+}
+
+__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ do_your_job_here;
+ }
+ }
+ }
+}
+
+void MyDriverFunc(const Eigen::GpuDevice &d) {
+ // use heuristics
+ CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
+ Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
+ Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
+
+ // maximize occupancy
+ CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
+ Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
+ MyKernel1D, 0, 0);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
+ Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
+ MyKernel1D, 0, 0);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+}
+
+// See the test for this for more example:
+//
+https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+
+*/
+
+namespace tensorflow {
+
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+
+struct CudaLaunchConfig {
+ // Logical number of thread that works on the elements. If each logical
+ // thread works on exactly a single element, this is the same as the working
+ // element count.
+ int virtual_thread_count = -1;
+ // Number of threads per block.
+ int thread_per_block = -1;
+ // Number of blocks for Cuda kernel launch.
+ int block_count = -1;
+};
+
+// Calculate the Cuda launch config we should use for a kernel launch.
+// This is assuming the kernel is quite simple and will largely be
+// memory-limited.
+// REQUIRES: work_element_count > 0.
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ const int virtual_thread_count = work_element_count;
+ const int physical_thread_count = std::min(
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
+ virtual_thread_count);
+ const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+ const int block_count =
+ std::min(DivUp(physical_thread_count, thread_per_block),
+ d.getNumCudaMultiProcessors());
+
+ config.virtual_thread_count = virtual_thread_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+// Calculate the Cuda launch config we should use for a kernel launch. This
+// variant takes the resource limits of func into account to maximize occupancy.
+// REQUIRES: work_element_count > 0.
+template <typename DeviceFunc>
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d,
+ DeviceFunc func,
+ size_t dynamic_shared_memory_size,
+ int block_size_limit) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ int block_count = 0;
+ int thread_per_block = 0;
+
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ block_count =
+ std::min(block_count, DivUp(work_element_count, thread_per_block));
+
+ config.virtual_thread_count = work_element_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+struct Cuda2DLaunchConfig {
+ dim3 virtual_thread_count = dim3(0, 0, 0);
+ dim3 thread_per_block = dim3(0, 0, 0);
+ dim3 block_count = dim3(0, 0, 0);
+};
+
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
+ const Eigen::GpuDevice& d) {
+ Cuda2DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0) {
+ return config;
+ }
+
+ const int kThreadsPerBlock = 256;
+ int block_cols = std::min(xdim, kThreadsPerBlock);
+ // ok to round down here and just do more loops in the kernel
+ int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+
+ const int physical_thread_count =
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
+
+ const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
+
+ config.virtual_thread_count = dim3(xdim, ydim, 1);
+ config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+ int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
+
+ config.block_count = dim3(
+ grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
+ return config;
+}
+
+// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
+// This variant takes the resource limits of func into account to maximize
+// occupancy.
+using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+
+template <typename DeviceFunc>
+inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
+ int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ Cuda3DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
+ return config;
+ }
+
+ int dev;
+ cudaGetDevice(&dev);
+ cudaDeviceProp deviceProp;
+ cudaGetDeviceProperties(&deviceProp, dev);
+ int xthreadlimit = deviceProp.maxThreadsDim[0];
+ int ythreadlimit = deviceProp.maxThreadsDim[1];
+ int zthreadlimit = deviceProp.maxThreadsDim[2];
+ int xgridlimit = deviceProp.maxGridSize[0];
+ int ygridlimit = deviceProp.maxGridSize[1];
+ int zgridlimit = deviceProp.maxGridSize[2];
+
+ int block_count = 0;
+ int thread_per_block = 0;
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); };
+
+ int threadsx = min3(xdim, thread_per_block, xthreadlimit);
+ int threadsy =
+ min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
+ int threadsz =
+ min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
+ zthreadlimit);
+
+ int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit);
+ int blocksy =
+ min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit);
+ int blocksz = min3(DivUp(block_count, (blocksx * blocksy)),
+ DivUp(zdim, threadsz), zgridlimit);
+
+ config.virtual_thread_count = dim3(xdim, ydim, zdim);
+ config.thread_per_block = dim3(threadsx, threadsy, threadsz);
+ config.block_count = dim3(blocksx, blocksy, blocksz);
+ return config;
+}
+
+template <typename DeviceFunc>
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
+ int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
+ dynamic_shared_memory_size, block_size_limit);
+}
+
+// Returns a raw reference to the current cuda stream. Required by a
+// number of kernel calls (for which StreamInterface* does not work), i.e.
+// CUB and certain cublas primitives.
+inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
+ const cudaStream_t* ptr = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ return *ptr;
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_