author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-26 05:15:18 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-26 05:22:51 -0800
commit | abdc62aee1eeba32be56d761a2f9988306356084 (patch)
tree | d49fa0848904b212443245702923255ba18cca58
parent | c8c2e4932afccb594bfe05e22facea1aba9dd454 (diff)
Roll CL 179861781 forward with fix: Wrappers for CUDA 9 warp-synchronous intrinsics.
PiperOrigin-RevId: 183374082
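The core of this change is the set of *_Sync wrappers (CudaShuffleSync, CudaShuffleXorSync, CudaBallotSync, CudaSyncWarp, and friends) added in cuda_device_functions.h, which map to the CUDA 9 warp-synchronous intrinsics and fall back to the legacy intrinsics on older toolkits. As a rough illustration of the intended calling convention, here is a hypothetical warp-wide sum reduction written against these wrappers; only the wrapper names, kCudaWarpAll, CudaLdg, and CudaAtomicAdd come from this commit, the kernel itself does not.

// Illustrative sketch, not part of the commit: a warp-wide sum reduction
// using the wrappers added in cuda_device_functions.h. Assumes blockDim.x and
// count are multiples of warpSize, so whole warps exit together and every
// lane named in the kCudaWarpAll mask reaches the shuffle in convergence.
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

template <typename T>
__global__ void WarpSumExampleKernel(const T* in, T* out, int count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= count) return;  // whole warps return together under the assumptions above
  T val = CudaLdg(in + i);
  // Butterfly reduction: after log2(warpSize) steps every lane holds the warp sum.
  for (int delta = warpSize / 2; delta > 0; delta /= 2) {
    val += CudaShuffleXorSync(kCudaWarpAll, val, delta);
  }
  if ((threadIdx.x & (warpSize - 1)) == 0) {
    CudaAtomicAdd(out, val);  // one atomic add per warp
  }
}

}  // namespace tensorflow

This mirrors the rewritten reduction in bias_op_gpu.cu.cc below, which replaces the hand-rolled shared-memory tree with the same CudaShuffleXorSync butterfly.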
-rw-r--r-- | tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc | 11
-rw-r--r-- | tensorflow/core/BUILD | 7
-rw-r--r-- | tensorflow/core/kernels/bias_op_gpu.cu.cc | 18
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 13
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc | 21
-rw-r--r-- | tensorflow/core/kernels/svd_op_gpu.cu.cc | 4
-rw-r--r-- | tensorflow/core/util/cuda_device_functions.h | 499
-rw-r--r-- | tensorflow/core/util/cuda_kernel_helper.h | 857
-rw-r--r-- | tensorflow/core/util/cuda_kernel_helper_test.cu.cc | 60
-rw-r--r-- | tensorflow/core/util/cuda_launch_config.h | 284 |
10 files changed, 988 insertions, 786 deletions
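The first hunks update callers to the revised CUDA_AXIS_KERNEL_LOOP macro, which now takes the scalar per-axis extent (config.virtual_thread_count.x and so on) and an upper-case axis letter selecting CudaGridRangeX/Y/Z. A minimal sketch of a 2-D kernel written against the new signature, assuming cuda_kernel_helper.h is included; the kernel body is hypothetical and only shows the macro usage.

// Hypothetical example of the new CUDA_AXIS_KERNEL_LOOP signature; only the
// macro, Cuda2DLaunchConfig, and CudaLdg come from this change, the kernel does not.
__global__ void ScaleExampleKernel(Cuda2DLaunchConfig config, float alpha,
                                   const float* in, float* out) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
      const int idx = x * config.virtual_thread_count.y + y;
      out[idx] = alpha * CudaLdg(in + idx);
    }
  }
}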
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc index 8e6870fadd..501cddb8c8 100644 --- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc +++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc @@ -34,9 +34,9 @@ namespace functor { __global__ void ReduceSliceDeviceKernel##reduceop( \ Cuda3DLaunchConfig config, Index indices_width, Index bound, \ const T begin, const Index *indices, const T *input, T *out) { \ - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \ - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \ - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \ + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \ + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \ + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \ Index outidx = x * config.virtual_thread_count.y * \ config.virtual_thread_count.z + \ y * config.virtual_thread_count.z + z; \ @@ -68,8 +68,9 @@ namespace functor { if (sizex * sizey * sizez == 0) { \ return; \ } \ - Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\ - ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \ + Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \ + sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \ + 0, 0); \ \ ReduceSliceDeviceKernel##reduceop<T, Index> \ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 94973a0e52..29c515121e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1896,6 +1896,13 @@ cc_library( ], ) +tf_cuda_library( + name = "cuda_device_functions", + hdrs = ["util/cuda_device_functions.h"], + visibility = ["//visibility:public"], + deps = [":framework_lite"], +) + # TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"? cc_library( name = "protos_cc", diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 42f3db1d79..2ca194a77f 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -173,19 +173,13 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, // Accumulate the results in the shared memory into the first element. // No syncthreads is needed since this is only in the same warp. int32 thread_index = threadIdx.x; - if (thread_index < 16) { - s_data[thread_index] += s_data[thread_index + 16]; - __syncwarp(0xFFFF); - if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8]; - __syncwarp(0xFF); - if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4]; - __syncwarp(0xF); - if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2]; - __syncwarp(0x3); + if (thread_index < 32) { + AccT data = s_data[thread_index]; + for (int32 delta = warpSize / 2; delta > 0; delta /= 2) { + data += CudaShuffleXorSync(kCudaWarpAll, data, delta); + } if (thread_index == 0) { - T val = T(s_data[0] + s_data[1]); - // The first thread writes out the accumulated result to global location. 
- CudaAtomicAdd(bias_backprop + bias_index, val); + CudaAtomicAdd(bias_backprop + bias_index, T(data)); } } } diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 903aac5d68..5493e33532 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -34,6 +34,7 @@ limitations under the License. namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; using Eigen::GpuDevice; // Returns whether depthwise convolution forward or backward input pass can be @@ -1028,7 +1029,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) { int zeros = sub_warp * kWidth; unsigned mask = ((1UL << kWidth) - 1) << zeros; for (int delta = kWidth / 2; delta > 0; delta /= 2) { - val += CudaShuffleXor(mask, val, delta); + val += CudaShuffleXorSync(mask, val, delta); } return val; } @@ -1145,7 +1146,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range); if (depth_in_range) { const T* const out_ptr = inout_offset + output; @@ -1159,7 +1160,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. for (int delta = 16; delta >= kBlockSlices; delta /= 2) { - val += CudaShuffleDown(active_threads, val, delta); + val += CudaShuffleXorSync(active_threads, val, delta); } if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) { *accum_ptr = val; @@ -1399,7 +1400,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range); if (slice_in_range) { const T* const out_ptr = inout_offset + output; @@ -1413,10 +1414,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) { - val += CudaShuffleDown(active_threads, val, delta); + val += CudaShuffleXorSync(active_threads, val, delta); } if (!(thread_idx & 32 / kBlockSlices - 1)) { - *accum_ptr = val; + *accum_ptr = val; // kBlockSlices threads per warp. } ++shared_offset; accum_ptr += accum_increment; diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index 31f74671ca..a3c21edc15 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -55,6 +55,27 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> { } }; +// Specializations for std::complex, updating real and imaginary part +// individually. Even though this is not an atomic op anymore, it is safe +// because there is only one type of op per kernel. 
+template <typename T> +struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()( + std::complex<T>* out, const std::complex<T>& val) { + T* ptr = reinterpret_cast<T*>(out); + CudaAtomicAdd(ptr, val.real()); + CudaAtomicAdd(ptr, val.imag()); + } +}; + +template <typename T> +struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()( + std::complex<T>* out, const std::complex<T>& val) { + LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val); + } +}; + } // namespace template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM> diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc index dedc2da60b..8c3a58b108 100644 --- a/tensorflow/core/kernels/svd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc @@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m, int64 ldu, const Scalar* M, const Scalar* U, const Scalar* S, Scalar* V) { - CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) { + CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) { Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch]; CudaAtomicAdd(V + batch, v); } diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h new file mode 100644 index 0000000000..f787687f66 --- /dev/null +++ b/tensorflow/core/util/cuda_device_functions.h @@ -0,0 +1,499 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ +#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_ + +/** + * Wrappers and helpers for CUDA device code. + * + * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide + * backwards compatibility, see go/volta-porting for details. + * Provides atomic operations on types that aren't natively supported. + */ + +#if GOOGLE_CUDA + +#include <algorithm> +#include <complex> +#include "cuda/include/cuda.h" +#include "cuda/include/device_functions.h" +#include "tensorflow/core/platform/types.h" + +#if CUDA_VERSION >= 7050 +#include "cuda/include/cuda_fp16.h" +#endif // CUDA_VERSION >= 7050 + +namespace tensorflow { + +namespace detail { + +// Helper for range-based for loop using 'delta' increments. +// Usage: see CudaGridRange?() functions below. 
+template <typename T> +class CudaGridRange { + struct Iterator { + __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {} + __device__ T operator*() const { return index_; } + __device__ Iterator& operator++() { + index_ += delta_; + return *this; + } + __device__ bool operator!=(const Iterator& other) const { + bool greater = index_ > other.index_; + bool less = index_ < other.index_; + // Anything past an end iterator (delta_ == 0) is equal. + // In range-based for loops, this optimizes to 'return less'. + if (!other.delta_) { + return less; + } + if (!delta_) { + return greater; + } + return less || greater; + } + + private: + T index_; + const T delta_; + }; + + public: + __device__ CudaGridRange(T begin, T delta, T end) + : begin_(begin), delta_(delta), end_(end) {} + + __device__ Iterator begin() const { return Iterator{begin_, delta_}; } + __device__ Iterator end() const { return Iterator{end_, 0}; } + + private: + T begin_; + T delta_; + T end_; +}; + +} // namespace detail + +// Helper to visit indices in the range 0 <= i < count, using the x-coordinate +// of the global thread index. That is, each index i is visited by all threads +// with the same x-coordinate. +// Usage: for(int i : CudaGridRangeX(count)) { visit(i); } +template <typename T> +__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) { + return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x, + gridDim.x * blockDim.x, count); +} + +// Helper to visit indices in the range 0 <= i < count using the y-coordinate. +// Usage: for(int i : CudaGridRangeY(count)) { visit(i); } +template <typename T> +__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) { + return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y, + gridDim.y * blockDim.y, count); +} + +// Helper to visit indices in the range 0 <= i < count using the z-coordinate. +// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); } +template <typename T> +__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) { + return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z, + gridDim.z * blockDim.z, count); +} + +// Mask for all 32 threads in a warp. +const unsigned kCudaWarpAll = 0xffffffff; + +// Returns the warp lane ID of the calling thread +__device__ inline unsigned CudaLaneId() { + unsigned int lane_id; + asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +namespace detail { +// Returns true if mask is a valid parameter for __shfl*sync to return a well +// defined value, assuming the calling lane will read from src_lane as part of +// the shuffle operation. +// +// Specifically, returns true iff mask has the calling lane bit and the src_lane +// bit set, and the src_lane calls this function with the same mask value +// (required for the two threads to wait for each other). +// +// On Volta, for some invalid masks, this function hangs or returns false +// positives, because the implementation shuffles with the same mask that +// we are validating. Run on Pascal if you suspect that the mask is incorrect. +__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask, + unsigned src_lane) { + unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane; +#if CUDA_VERSION >= 9000 + unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane); +#else + unsigned src_lane_mask = __shfl(mask, src_lane); +#endif + return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask; +} + +// Returns the actual source lane for shuffle. 
+__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) { + int lane_id = CudaLaneId(); + int lane_base = lane_id & ~width + 1; + int lane_offset = src_lane & width - 1; + return lane_base + lane_offset; +} + +// Returns the source lane for shuffle up. +__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) { + unsigned lane_id = CudaLaneId(); + if ((lane_id & width - 1) < delta) { + return lane_id; + } + return lane_id - delta; +} + +// Returns the source lane for shuffle down. +__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta, + int width) { + unsigned lane_id = CudaLaneId(); + if ((lane_id & width - 1) + delta >= width) { + return lane_id; + } + return lane_id + delta; +} + +// Returns the source lane for shuffle xor. +__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) { + int lane_id = CudaLaneId(); + int src_lane = lane_id ^ lane_mask; + if (src_lane > (lane_id | width - 1)) { + return lane_id; + } + return src_lane; +} +} // namespace detail + +// For all *_sync wrappers below, it is illegal to synchronize threads from +// different program locations, because that is not supported before sm_70. +// In other words, all threads in 'mask' must call the functions in convergence. +// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly. +// +// It is also illegal to shuffle with a mask that produces an undefined result +// for any of the threads. Specifically, all source threads of the shuffle +// must have their corresponding bit in 'mask' set. + +// Wrapper for __syncwarp. No-op for CUDA 8 and earlier. +__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + __syncwarp(mask); +#endif +} + +// Wrapper for __ballot_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __ballot_sync(mask, pred); +#else + return __ballot(pred) & mask; // Apply mask to match __ballot_sync's spec. +#endif +} + +// Wrapper for __any_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int CudaAnySync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __any_sync(mask, pred); +#else + return __any(pred); +#endif +} + +// Wrapper for __all_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int CudaAllSync(unsigned mask, int pred) { + assert(mask & 1u << CudaLaneId()); +#if CUDA_VERSION >= 9000 + return __all_sync(mask, pred); +#else + return __all(pred); +#endif +} + +// Wrapper for __shfl_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template <typename T> +__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleGetSrcLane(src_lane, width))); +#if CUDA_VERSION >= 9000 + return __shfl_sync(mask, value, src_lane, width); +#else + return __shfl(value, src_lane, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. 
+__device__ inline double CudaShuffleSync(unsigned mask, double value, + int src_lane, int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleSync(mask, hi, src_lane, width); + lo = CudaShuffleSync(mask, lo, src_lane, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template <typename T> +__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleUpGetSrcLane(delta, width))); +#if CUDA_VERSION >= 9000 + return __shfl_up_sync(mask, value, delta, width); +#else + return __shfl_up(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleUpSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleUpSync(mask, hi, delta, width); + lo = CudaShuffleUpSync(mask, lo, delta, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function +// in convergence, see comment above for details. +template <typename T> +__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleDownGetSrcLane(delta, width))); +#if CUDA_VERSION >= 9000 + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double CudaShuffleDownSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleDownSync(mask, hi, delta, width); + lo = CudaShuffleDownSync(mask, lo, delta, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template <typename T> +__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::CudaValidateShuffleSyncMask( + mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width))); +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, lane_mask, width); +#else + return __shfl_xor(value, lane_mask, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. 
+__device__ inline double CudaShuffleXorSync(unsigned mask, double value, + int lane_mask, + int width = warpSize) { + unsigned lo, hi; + asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); + hi = CudaShuffleXorSync(mask, hi, lane_mask, width); + lo = CudaShuffleXorSync(mask, lo, lane_mask, width); + asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); + return value; +} + +// Wrapper for __ldg. +template <typename T> +__host__ __device__ T CudaLdg(const T* address) { +#if __CUDA_ARCH__ >= 350 + return __ldg(address); +#else + return *address; +#endif +} + +__host__ __device__ inline bool CudaLdg(const bool* address) { + return CudaLdg(reinterpret_cast<const char*>(address)) != 0; +} + +__host__ __device__ inline std::complex<float> CudaLdg( + const std::complex<float>* address) { +#if __CUDA_ARCH__ >= 350 + float2 mem = __ldg(reinterpret_cast<const float2*>(address)); + return std::complex<float>(mem.x, mem.y); +#else + return *address; +#endif +} + +__host__ __device__ inline std::complex<double> CudaLdg( + const std::complex<double>* address) { +#if __CUDA_ARCH__ >= 350 + double2 mem = __ldg(reinterpret_cast<const double2*>(address)); + return std::complex<double>(mem.x, mem.y); +#else + return *address; +#endif +} + +// Zeroes count elements starting at ptr using all threads of a 1-D grid. +// Note: this function does not synchronize, and therefore the memory range is +// not guaranteed to be zero until the next kernel launch. +template <typename T> +__global__ void SetZero(const int count, T* ptr) { + // Check that the grid is one dimensional and index doesn't overflow. + assert(blockDim.y == 1 && blockDim.z == 1); + assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); + for (int i : CudaGridRangeX(count)) { + ptr[i] = T(0); + } +} + +namespace detail { +// Helper function for atomic accumulation implemented as CAS. +template <typename T, typename F> +__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) { + T old = *ptr; + T assumed; + do { + assumed = old; + old = atomicCAS(ptr, assumed, accumulate(assumed)); + } while (assumed != old); + return old; +} + +// Overload for floating point (using integer comparison to handle NaN +// correctly). +template <typename F> +__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) { + return __float_as_int( + CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) { + return __float_as_int(accumulate(__int_as_float(a))); + })); +} +template <typename F> +__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) { + return __longlong_as_double(CudaAtomicCasHelper( + reinterpret_cast<tensorflow::uint64*>(ptr), + [accumulate](tensorflow::uint64 a) { + return __double_as_longlong(accumulate(__longlong_as_double(a))); + })); +} + +template <typename From, typename To> +using ToTypeIfConvertible = + typename std::enable_if<std::is_convertible<From, To>::value, To>::type; + +} // namespace detail + +// CUDA provides atomic ops, but not for all types. We provide wrappers +// for some ops and provide implementation for all reasonable types. + +template <typename T, typename U> +__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) { + return atomicAdd(ptr, value); +} +#if __CUDA_ARCH__ < 600 +__device__ inline double CudaAtomicAdd(double* ptr, double value) { + return detail::CudaAtomicCasHelper(ptr, + [value](double a) { return a + value; }); +} +#elif __clang__ +// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX. 
+// see https://reviews.llvm.org/D39638 +__device__ inline double CudaAtomicAdd(double* ptr, double value) { + double result; + asm volatile("atom.add.f64 %0, [%1], %2;" + : "=d"(result) + : "l"(ptr), "d"(value) + : "memory"); + return result; +} +#endif + +template <typename T, typename U> +__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) { + return atomicSub(ptr, value); +} +// Specializations of substraction which add the negative value. +__device__ inline float CudaAtomicSub(float* ptr, float value) { + return CudaAtomicAdd(ptr, -value); +} +__device__ inline double CudaAtomicSub(double* ptr, double value) { + return CudaAtomicAdd(ptr, -value); +} +__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return CudaAtomicAdd(ptr, -value); +} + +template <typename T, typename U> +__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) { + return atomicMax(ptr, value); +} +#if __CUDA_ARCH__ < 320 +__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr, + tensorflow::uint64 value) { + return detail::CudaAtomicCasHelper( + ptr, [value](tensorflow::uint64 a) { return max(a, value); }); +} +#endif + +template <typename T, typename U> +__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) { + return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; }); +} +template <typename T, typename U> +__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) { + return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; }); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 3e32ec7973..18a4c008f1 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -18,299 +18,133 @@ limitations under the License. #if GOOGLE_CUDA -#include <algorithm> +#include "tensorflow/core/util/cuda_device_functions.h" +#include "tensorflow/core/util/cuda_launch_config.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "cuda/include/cuda.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/platform/types.h" +// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead. +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i : ::tensorflow::CudaGridRangeX<int>(n)) +// Deprecated, use 'for(int i : CudaGridRange?(n))' instead. +#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \ + for (int i : ::tensorflow::CudaGridRange##axis<int>(n)) -// Mask for all 32 threads in a warp. -#define CUDA_WARP_ALL 0xFFFFFFFF - -#if defined(CUDA_VERSION) && CUDA_VERSION < 9000 -// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive -// that operates at the warp-scope. This is required to ensure visibility of -// reads/writes among threads that can make indepenent progress on Volta. -// For previous CUDA versions these synchronizations not necessary, and we -// define an empty function as a convenience for backward compatibility. -__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {} - -// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in -// favor of synchronizing versions. 
These ensure that all warp lanes specified -// in mask execute the intrinsic in convergence. Here we provide legacy mappings -// to the less-verbose routines provided in previous versions of CUDA. -#define __ballot_sync(mask, predicate) __ballot(predicate) -#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width) -#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width) -#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width) -#define __shfl_xor_sync(mask, val, laneMask, width) \ - __shfl_xor(val, laneMask, width) -#endif - -// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and -// GetCuda3DLaunchConfig: -// -// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one -// version uses heuristics without any knowledge of the device kernel, the other -// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical -// launch parameters that maximize occupancy. Currently, only the maximum -// occupancy version of GetCuda3DLaunchConfig is available. -// -// For large number of work elements, the convention is that each kernel would -// iterate through its assigned range. The return value of GetCudaLaunchConfig -// is struct CudaLaunchConfig, which contains all the information needed for the -// kernel launch, including: virtual number of threads, the number of threads -// per block and number of threads per block used inside <<< >>> of a kernel -// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing -// as CudaLaunchConfig. The only difference is the dimension. The macros -// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop. -// -/* Sample code: - -__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) { - CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { - do_your_job_here; - } +namespace tensorflow { +__host__ __device__ inline tensorflow::bfloat16 CudaLdg( + const tensorflow::bfloat16* address) { + tensorflow::bfloat16 return_value; + return_value.value = CudaLdg(reinterpret_cast<const uint16_t*>(address)); + return return_value; } -__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { - do_your_job_here; - } - } +template <typename T> +__host__ __device__ inline T ldg(const T* ptr) { + return CudaLdg(ptr); } -__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { - do_your_job_here; - } - } - } +template <typename T> +__host__ __device__ inline const T& tf_min(const T& x, const T& y) { + return x < y ? 
x : y; } -void MyDriverFunc(const GPUDevice &d) { - // use heuristics - CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d); - MyKernel1D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...); - Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d); - MyKernel2D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...); - Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d); - MyKernel3D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...); - - // maximize occupancy - CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 ); - MyKernel1D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...); - Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d, - MyKernel1D, 0, 0); - MyKernel2D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...); - Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d, - MyKernel1D, 0, 0); - MyKernel3D <<<config.block_count, - config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...); +template <typename T> +__host__ __device__ inline const T& tf_max(const T& x, const T& y) { + return x < y ? y : x; } -// See the test for this for more example: -// -https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc - -*/ - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - -#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \ - for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \ - i += blockDim.axis * gridDim.axis) - -#define DIV_UP(a, b) (((a) + (b)-1) / (b)) - -namespace tensorflow { - -typedef Eigen::GpuDevice GPUDevice; - -struct CudaLaunchConfig { - // Logical number of thread that works on the elements. If each logical - // thread works on exactly a single element, this is the same as the working - // element count. - int virtual_thread_count = -1; - // Number of threads per block. - int thread_per_block = -1; - // Number of blocks for Cuda kernel launch. - int block_count = -1; -}; - -// Calculate the Cuda launch config we should use for a kernel launch. -// This is assuming the kernel is quite simple and will largely be -// memory-limited. -// REQUIRES: work_element_count > 0. -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const GPUDevice& d) { - CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; - const int virtual_thread_count = work_element_count; - const int physical_thread_count = std::min( - d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(), - virtual_thread_count); - const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock()); - const int block_count = - std::min(DIV_UP(physical_thread_count, thread_per_block), - d.getNumCudaMultiProcessors()); - - config.virtual_thread_count = virtual_thread_count; - config.thread_per_block = thread_per_block; - config.block_count = block_count; - return config; +// Overloads of the above functions for float and double. +__host__ __device__ inline float tf_min(float x, float y) { + return fminf(x, y); } - -// Calculate the Cuda launch config we should use for a kernel launch. This -// variant takes the resource limits of func into account to maximize occupancy. -// REQUIRES: work_element_count > 0. 
-template <typename DeviceFunc> -inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, - const GPUDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, - int block_size_limit) { - CHECK_GT(work_element_count, 0); - CudaLaunchConfig config; - int block_count = 0; - int thread_per_block = 0; - - cudaError_t err = cudaOccupancyMaxPotentialBlockSize( - &block_count, &thread_per_block, func, dynamic_shared_memory_size, - block_size_limit); - CHECK_EQ(err, cudaSuccess); - - block_count = - std::min(block_count, DIV_UP(work_element_count, thread_per_block)); - - config.virtual_thread_count = work_element_count; - config.thread_per_block = thread_per_block; - config.block_count = block_count; - return config; +__host__ __device__ inline double tf_min(double x, double y) { + return fmin(x, y); +} +__host__ __device__ inline float tf_max(float x, float y) { + return fmaxf(x, y); +} +__host__ __device__ inline double tf_max(double x, double y) { + return fmax(x, y); } -struct Cuda2DLaunchConfig { - dim3 virtual_thread_count = dim3(0, 0, 0); - dim3 thread_per_block = dim3(0, 0, 0); - dim3 block_count = dim3(0, 0, 0); -}; - -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, - const GPUDevice& d) { - Cuda2DLaunchConfig config; - - if (xdim <= 0 || ydim <= 0) { - return config; - } - - const int kThreadsPerBlock = 256; - int block_cols = std::min(xdim, kThreadsPerBlock); - // ok to round down here and just do more loops in the kernel - int block_rows = std::max(kThreadsPerBlock / block_cols, 1); - - const int physical_thread_count = - d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(); - - const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1); - - config.virtual_thread_count = dim3(xdim, ydim, 1); - config.thread_per_block = dim3(block_cols, block_rows, 1); - - int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks); +__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value, + int src_lane, + int width = warpSize) { + return Eigen::half( + CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width)); +} - config.block_count = dim3( - grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1); - return config; +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync( + unsigned mask, Eigen::half value, int delta, int width = warpSize) { + return Eigen::half( + CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width)); } -// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch. -// This variant takes the resource limits of func into account to maximize -// occupancy. -using Cuda3DLaunchConfig = Cuda2DLaunchConfig; +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync( + unsigned mask, Eigen::half value, int delta, int width = warpSize) { + return Eigen::half( + CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width)); +} -template <typename DeviceFunc> -inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( - int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func, - size_t dynamic_shared_memory_size, int block_size_limit) { - Cuda3DLaunchConfig config; +__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync( + unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) { + return Eigen::half( + CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width)); +} - if (xdim <= 0 || ydim <= 0 || zdim <= 0) { - return config; +namespace detail { +// Overload of above function for half. 
Note that we don't have +// atomicCAS() for anything less than 32 bits, so we need to include the +// other 16 bits in the operation. +// +// This version is going to be very slow +// under high concurrency, since most threads will be spinning on failing +// their compare-and-swap tests. (The fact that we get false sharing on the +// neighboring fp16 makes this even worse.) If you are doing a large reduction, +// you are much better off with doing the intermediate steps in fp32 and then +// switching to fp16 as late as you can in the calculations. +// +// Note: Assumes little endian. +template <typename F> +__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) { +#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) + static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian"); +#endif + namespace half_impl = Eigen::half_impl; + intptr_t intptr = reinterpret_cast<intptr_t>(ptr); + assert(!(intptr & 0x1)); // should be 2-aligned. + if (intptr & 0x2) { + // The half is in the second part of the uint32 (upper 16 bits). + uint32* address = reinterpret_cast<uint32*>(intptr - 2); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short high = static_cast<unsigned short>(arg >> 16); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high)); + return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff); + }); + return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16)); + } else { + // The half is in the first part of the uint32 (lower 16 bits). + uint32* address = reinterpret_cast<uint32*>(intptr); + uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) { + unsigned short low = static_cast<unsigned short>(arg & 0xffff); + Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low)); + return (arg & 0xffff0000) | static_cast<uint32>(acc.x); + }); + return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff)); } - - int dev; - cudaGetDevice(&dev); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, dev); - int xthreadlimit = deviceProp.maxThreadsDim[0]; - int ythreadlimit = deviceProp.maxThreadsDim[1]; - int zthreadlimit = deviceProp.maxThreadsDim[2]; - int xgridlimit = deviceProp.maxGridSize[0]; - int ygridlimit = deviceProp.maxGridSize[1]; - int zgridlimit = deviceProp.maxGridSize[2]; - - int block_count = 0; - int thread_per_block = 0; - cudaError_t err = cudaOccupancyMaxPotentialBlockSize( - &block_count, &thread_per_block, func, dynamic_shared_memory_size, - block_size_limit); - CHECK_EQ(err, cudaSuccess); - -#define MIN3(a, b, c) std::min((a), std::min((b), (c))) - int threadsx = MIN3(xdim, thread_per_block, xthreadlimit); - int threadsy = - MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit); - int threadsz = - MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1), - zthreadlimit); - - int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit); - int blocksy = - MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit); - int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)), - DIV_UP(zdim, threadsz), zgridlimit); -#undef MIN3 - - config.virtual_thread_count = dim3(xdim, ydim, zdim); - config.thread_per_block = dim3(threadsx, threadsy, threadsz); - config.block_count = dim3(blocksx, blocksy, blocksz); - return config; } +} // namespace detail -template <typename DeviceFunc> -inline Cuda2DLaunchConfig GetCuda2DLaunchConfig( - int xdim, int ydim, const GPUDevice& d, DeviceFunc func, - 
size_t dynamic_shared_memory_size, int block_size_limit) { - return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func, - dynamic_shared_memory_size, block_size_limit); +__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a + value; }); } - -// Returns a raw reference to the current cuda stream. Required by a -// number of kernel calls (for which StreamInterface* does not work), i.e. -// CUB and certain cublas primitives. -inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { - const cudaStream_t* ptr = CHECK_NOTNULL( - reinterpret_cast<const cudaStream_t*>(context->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - return *ptr; +__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr, + Eigen::half value) { + return detail::CudaAtomicCasHelper( + ptr, [value](Eigen::half a) { return a - value; }); } namespace cuda_helper { - template <typename IntType> __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { IntType* orig = first; @@ -330,495 +164,8 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { return first - orig; } - } // namespace cuda_helper - -template <typename T> -__device__ __host__ inline T ldg(const T* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return __ldg(address); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline std::complex<float> ldg( - const std::complex<float>* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - float2 mem = __ldg(reinterpret_cast<const float2*>(address)); - return std::complex<float>(mem.x, mem.y); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline std::complex<double> ldg( - const std::complex<double>* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - double2 mem = __ldg(reinterpret_cast<const double2*>(address)); - return std::complex<double>(mem.x, mem.y); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return Eigen::half_impl::raw_uint16_to_half( - __ldg(reinterpret_cast<const uint16_t*>(address))); -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline tensorflow::bfloat16 ldg( - const tensorflow::bfloat16* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - tensorflow::bfloat16 return_value; - asm volatile("ld.global.nc.u16 %0, [%1];" - : "=h"(return_value.value) - : "l"(address)); - return return_value; -#else - return *address; -#endif -} - -template <> -__device__ __host__ inline bool ldg(const bool* address) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 - return *reinterpret_cast<const bool*>( - __ldg(reinterpret_cast<const char*>(address))); -#else - return *address; -#endif -} - -// CUDA provides atomic ops, but not for all types. We provide wrappers -// for some ops and provide implementation for all reasonable types. -#define CUDA_ATOMIC_WRAPPER(op, T) \ - __device__ __forceinline__ T CudaAtomic##op(T* address, T val) - -#define USE_CUDA_ATOMIC(op, T) \ - CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } - -// For atomicAdd. -USE_CUDA_ATOMIC(Add, int32); -USE_CUDA_ATOMIC(Add, uint32); -USE_CUDA_ATOMIC(Add, uint64); -USE_CUDA_ATOMIC(Add, float); - -// For atomicMax. 
-USE_CUDA_ATOMIC(Max, int32); -USE_CUDA_ATOMIC(Max, uint32); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 -USE_CUDA_ATOMIC(Max, uint64); -#else -// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >= -// 350. If not satisfied, we provide a custom implementation using atomicCAS(). -CUDA_ATOMIC_WRAPPER(Max, uint64) { - uint64* address_as_ull = reinterpret_cast<uint64*>(address); - uint64 old = *address_as_ull, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, max(val, assumed)); - } while (assumed != old); - - return old; -} -#endif - -// Custom implementation of atomicAdd for double. -// This implementation is copied from CUDA manual. -CUDA_ATOMIC_WRAPPER(Add, double) { - uint64* address_as_ull = reinterpret_cast<uint64*>(address); - uint64 old = *address_as_ull, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - return __longlong_as_double(old); -} - -// Custom implementation of atomicAdd for std::complex<float>. -// This implementation performs to atomic additions on the components. -CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) { -#if defined(__CUDA_ARCH__) -#if __CUDA_ARCH__ >= 350 - float2* addr_as_float2 = reinterpret_cast<float2*>(address); - float2* val_as_float2 = reinterpret_cast<float2*>(&val); - CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x); - CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y); -#else - static_assert(sizeof(std::complex<float>) == 2 * sizeof(float), - "Unable to compile CudaAtomicAdd for complex64 because " - "sizeof(complex64) != 2*sizeof(float32)"); - float* addr_as_float = reinterpret_cast<float*>(address); - float* val_as_float = reinterpret_cast<float*>(&val); - CudaAtomicAdd(addr_as_float, *val_as_float); - CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1)); -#endif -#endif - return *address; -} - -// Custom implementation of atomicAdd for std::complex<double>. -// This implementation performs to atomic additions on the components -// using the double atomic wrapper above. -CUDA_ATOMIC_WRAPPER(Add, complex128) { -#if defined(__CUDA_ARCH__) -#if __CUDA_ARCH__ >= 350 - double2* addr_as_double2 = reinterpret_cast<double2*>(address); - double2* val_as_double2 = reinterpret_cast<double2*>(&val); - CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x); - CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y); -#else - static_assert(sizeof(std::complex<double>) == 2 * sizeof(double), - "Unable to compile CudaAtomicAdd for complex128 because " - "sizeof(complex128) != 2*sizeof(float64)"); - double* addr_as_double = reinterpret_cast<double*>(address); - double* val_as_double = reinterpret_cast<double*>(&val); - CudaAtomicAdd(addr_as_double, *val_as_double); - CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1)); -#endif -#endif - return *address; -} - -// Helper functions for CudaAtomicAdd(half*, half), below. -// -// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2() -// for a more efficient implementation, assuming that adding -0.0 -// will never harm the neighboring value. In this version, we take special -// care to guarantee the bits of the untouched value are unchanged. 
-inline __device__ uint32 add_to_low_half(uint32 val, float x) { - Eigen::half low_half; - low_half.x = static_cast<uint16>(val & 0xffffu); - low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x); - return (val & 0xffff0000u) | low_half.x; -} - -inline __device__ uint32 add_to_high_half(uint32 val, float x) { - Eigen::half high_half; - high_half.x = static_cast<uint16>(val >> 16); - high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x); - return (val & 0xffffu) | (high_half.x << 16); -} - -// Custom implementation of atomicAdd for half. Note that we don't have -// atomicCAS() for anything less than 32 bits, so we need to include the -// other 16 bits in the operation. -// -// Unlike the other atomic adds, this version is going to be very slow -// under high concurrency, since most threads will be spinning on failing -// their compare-and-swap tests. (The fact that we get false sharing on the -// neighboring fp16 makes this even worse.) If you are doing a large reduction, -// you are much better off with doing the intermediate steps in fp32 and then -// switching to fp16 as late as you can in the calculations. -// -// Note: Assumes little endian. -CUDA_ATOMIC_WRAPPER(Add, Eigen::half) { - float val_as_float(val); - intptr_t address_int = reinterpret_cast<intptr_t>(address); - if ((address_int & 0x2) == 0) { - // The half is in the first part of the uint32 (lower 16 bits). - uint32* address_as_uint32 = reinterpret_cast<uint32*>(address); - assert(((intptr_t)address_as_uint32 & 0x3) == 0); - uint32 old = *address_as_uint32, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_uint32, assumed, - add_to_low_half(assumed, val_as_float)); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - Eigen::half ret; - ret.x = old & 0xffffu; - return ret; - } else { - // The half is in the second part of the uint32 (upper 16 bits). - uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2); - assert(((intptr_t)address_as_uint32 & 0x3) == 0); - uint32 old = *address_as_uint32, assumed; - - do { - assumed = old; - old = atomicCAS(address_as_uint32, assumed, - add_to_high_half(assumed, val_as_float)); - - // Note: uses integer comparison to avoid hang in case of NaN - } while (assumed != old); - - Eigen::half ret; - ret.x = old >> 16; - return ret; - } -} - -template <typename T> -__global__ void SetZero(const int nthreads, T* bottom_diff) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); } -} - -// For atomicSub. - -// Custom implementation for sub by just negating the value. -#define WRAPPED_ATOMIC_SUB(T) \ - CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); } - -WRAPPED_ATOMIC_SUB(uint64); -WRAPPED_ATOMIC_SUB(int32); -WRAPPED_ATOMIC_SUB(uint32); -WRAPPED_ATOMIC_SUB(Eigen::half); -WRAPPED_ATOMIC_SUB(float); -WRAPPED_ATOMIC_SUB(double); - -CUDA_ATOMIC_WRAPPER(Sub, complex64) { - const std::complex<float> Tneg(-val.real(), -val.imag()); - return CudaAtomicAdd(address, Tneg); -} - -CUDA_ATOMIC_WRAPPER(Sub, complex128) { - const std::complex<double> Tneg(-val.real(), -val.imag()); - return CudaAtomicAdd(address, Tneg); -} - -#undef WRAPPED_ATOMIC_SUB - -// For atomicMul. 
-CUDA_ATOMIC_WRAPPER(Mul, int32) { - int32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, uint32) { - uint32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, uint64) { - uint64 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, val * assumed); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Mul, float) { - int32* address_as_int = reinterpret_cast<int32*>(address); - int32 old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(val * __int_as_float(assumed))); - } while (assumed != old); - return __int_as_float(old); -} - -CUDA_ATOMIC_WRAPPER(Mul, double) { - uint64* address_as_ull = reinterpret_cast<uint64*>(address); - uint64 old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val * __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} - -// For atomicDiv. -CUDA_ATOMIC_WRAPPER(Div, int32) { - int32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, uint32) { - uint32 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, uint64) { - uint64 old = *address, assumed; - do { - assumed = old; - old = atomicCAS(address, assumed, assumed / val); - } while (assumed != old); - return old; -} - -CUDA_ATOMIC_WRAPPER(Div, float) { - int32* address_as_int = reinterpret_cast<int32*>(address); - int32 old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(__int_as_float(assumed) / val)); - } while (assumed != old); - return __int_as_float(old); -} - -CUDA_ATOMIC_WRAPPER(Div, double) { - uint64* address_as_ull = reinterpret_cast<uint64*>(address); - uint64 old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(__longlong_as_double(assumed) / val)); - } while (assumed != old); - return __longlong_as_double(old); -} - -#undef USE_CUDA_ATOMIC -#undef CUDA_ATOMIC_WRAPPER - -template <typename T> -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) { - return x > y ? y : x; -} - -template <typename T> -EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) { - return x < y ? y : x; -} - -__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask, - int predicate) { - return __ballot_sync(mask, predicate); -} - -template <typename T> -__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value, - int srcLane, - int width = warpSize) { - return __shfl_sync(mask, value, srcLane, width); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. 
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value, - int srcLane, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_sync(mask, hi, srcLane, width); - lo = __shfl_sync(mask, lo, srcLane, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template <typename T> -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value, - int delta, - int width = warpSize) { - return __shfl_up_sync(mask, value, delta, width); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value, - int delta, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_up_sync(mask, hi, delta, width); - lo = __shfl_up_sync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template <typename T> -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value, - int delta, - int width = warpSize) { - return __shfl_down_sync(mask, value, delta, width); -} - -__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown( - unsigned mask, Eigen::half value, int delta, int width = warpSize) { - return Eigen::half( - __shfl_down_sync(mask, static_cast<uint16>(value), delta, width)); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. -__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask, - double value, int delta, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_down_sync(mask, hi, delta, width); - lo = __shfl_down_sync(mask, lo, delta, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - -template <typename T> -__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value, - int laneMask, - int width = warpSize) { - return __shfl_xor_sync(mask, value, laneMask, width); -} - -__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor( - unsigned mask, Eigen::half value, int laneMask, int width = warpSize) { - return Eigen::half( - __shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width)); -} - -// Variant of the (undocumented) version from the CUDA SDK, but using unsigned -// instead of float for lo and hi (which is incorrect with ftz, for example). -// A bug has been filed with NVIDIA and will be fixed in the next CUDA release. -// TODO(csigg): remove when the bug is fixed in the next CUDA release. 
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask, - double value, int laneMask, - int width = warpSize) { - unsigned lo, hi; - asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value)); - hi = __shfl_xor_sync(mask, hi, laneMask, width); - lo = __shfl_xor_sync(mask, lo, laneMask, width); - asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi)); - return value; -} - } // namespace tensorflow -#undef DIV_UP - #endif // GOOGLE_CUDA - #endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc index 6991554eff..bd4c356ea0 100644 --- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc +++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc @@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } @@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) { } } __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { - CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { if (x < 0) { // x might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { if (y < 0) { // y might overflow when testing extreme case break; } - CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { if (z < 0) { // z might overflow when testing extreme case break; } @@ -87,6 +87,44 @@ __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) { } } +__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) { + unsigned lane_id = CudaLaneId(); + for (int width = warpSize; width > 1; width /= 2) { + auto check_result = [&](const char* op_name, int param, unsigned actual, + unsigned expected) { + if (actual != expected) { + printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n", + op_name, param, width, lane_id, actual, expected); + CudaAtomicAdd(failure_count, 1); + } + }; + for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) { + unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width); + unsigned expect_lane = + CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width); + check_result("Shuffle", src_lane, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleUp", delta, actual_lane, expect_lane); + } + for (unsigned delta = 0; delta <= warpSize; ++delta) { + unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width); + unsigned expect_lane = + CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width); + check_result("ShuffleDown", delta, actual_lane, expect_lane); + } + for (int lane_lane 
= warpSize; lane_lane > 0; lane_lane /= 2) { + unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width); + unsigned expect_lane = + CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width); + check_result("ShuffleXor", lane_lane, actual_lane, expect_lane); + } + } +} + } // namespace class CudaLaunchConfigTest : public ::testing::Test { @@ -94,7 +132,7 @@ class CudaLaunchConfigTest : public ::testing::Test { const int bufsize = 1024; int* outbuf = nullptr; Eigen::CudaStreamDevice stream; - GPUDevice d = GPUDevice(&stream); + Eigen::GpuDevice d = Eigen::GpuDevice(&stream); virtual void SetUp() { cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize); @@ -229,6 +267,16 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) { #undef TEST_LAUNCH_PARAMETER } +TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) { + unsigned* failure_count; + ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess); + *failure_count = 0; + CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); + ASSERT_EQ(*failure_count, 0); + cudaFree(failure_count); +} + } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h new file mode 100644 index 0000000000..3ea33ee6cf --- /dev/null +++ b/tensorflow/core/util/cuda_launch_config.h @@ -0,0 +1,284 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ +#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_ + +#if GOOGLE_CUDA + +#include <algorithm> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "cuda/include/cuda.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" + +// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and +// GetCuda3DLaunchConfig: +// +// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one +// version uses heuristics without any knowledge of the device kernel, the other +// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical +// launch parameters that maximize occupancy. Currently, only the maximum +// occupancy version of GetCuda3DLaunchConfig is available. +// +// For large number of work elements, the convention is that each kernel would +// iterate through its assigned range. The return value of GetCudaLaunchConfig +// is struct CudaLaunchConfig, which contains all the information needed for the +// kernel launch, including: virtual number of threads, the number of threads +// per block and number of threads per block used inside <<< >>> of a kernel +// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing +// as CudaLaunchConfig. 
The only difference is the dimension. The macros +// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop. +// +/* Sample code: + +__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) { + CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { + do_your_job_here; + } +} + +__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + do_your_job_here; + } + } +} + +__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) { + CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { + CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { + CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { + do_your_job_here; + } + } + } +} + +void MyDriverFunc(const Eigen::GpuDevice &d) { + // use heuristics + CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d); + MyKernel1D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...); + Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d); + MyKernel2D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...); + Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d); + MyKernel3D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...); + + // maximize occupancy + CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 ); + MyKernel1D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...); + Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d, + MyKernel1D, 0, 0); + MyKernel2D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...); + Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d, + MyKernel1D, 0, 0); + MyKernel3D <<<config.block_count, + config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...); +} + +// See the test for this for more example: +// +https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc + +*/ + +namespace tensorflow { + +inline int DivUp(int a, int b) { return (a + b - 1) / b; } + +struct CudaLaunchConfig { + // Logical number of thread that works on the elements. If each logical + // thread works on exactly a single element, this is the same as the working + // element count. + int virtual_thread_count = -1; + // Number of threads per block. + int thread_per_block = -1; + // Number of blocks for Cuda kernel launch. + int block_count = -1; +}; + +// Calculate the Cuda launch config we should use for a kernel launch. +// This is assuming the kernel is quite simple and will largely be +// memory-limited. +// REQUIRES: work_element_count > 0. 
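One caveat about the sample code above: the launches reference config.block_count while the configs are named cfg1 through cfg6, so it is pseudocode rather than compilable C++. A minimal 1-D version that does compile against the helpers in this file (kernel and driver names are illustrative; CUDA_1D_KERNEL_LOOP is assumed to be in scope from cuda_kernel_helper.h):

__global__ void ScaleKernel(CudaLaunchConfig config, float* data, float alpha) {
  // Grid-stride loop over all virtual threads, however few physical threads
  // the config actually launched.
  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
    data[i] *= alpha;
  }
}

void Scale(const Eigen::GpuDevice& d, float* data, int n, float alpha) {
  CudaLaunchConfig cfg = GetCudaLaunchConfig(n, d);
  ScaleKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
      cfg, data, alpha);
}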
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + const int virtual_thread_count = work_element_count; + const int physical_thread_count = std::min( + d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(), + virtual_thread_count); + const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock()); + const int block_count = + std::min(DivUp(physical_thread_count, thread_per_block), + d.getNumCudaMultiProcessors()); + + config.virtual_thread_count = virtual_thread_count; + config.thread_per_block = thread_per_block; + config.block_count = block_count; + return config; +} + +// Calculate the Cuda launch config we should use for a kernel launch. This +// variant takes the resource limits of func into account to maximize occupancy. +// REQUIRES: work_element_count > 0. +template <typename DeviceFunc> +inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, + const Eigen::GpuDevice& d, + DeviceFunc func, + size_t dynamic_shared_memory_size, + int block_size_limit) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + int block_count = 0; + int thread_per_block = 0; + + cudaError_t err = cudaOccupancyMaxPotentialBlockSize( + &block_count, &thread_per_block, func, dynamic_shared_memory_size, + block_size_limit); + CHECK_EQ(err, cudaSuccess); + + block_count = + std::min(block_count, DivUp(work_element_count, thread_per_block)); + + config.virtual_thread_count = work_element_count; + config.thread_per_block = thread_per_block; + config.block_count = block_count; + return config; +} + +struct Cuda2DLaunchConfig { + dim3 virtual_thread_count = dim3(0, 0, 0); + dim3 thread_per_block = dim3(0, 0, 0); + dim3 block_count = dim3(0, 0, 0); +}; + +inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim, + const Eigen::GpuDevice& d) { + Cuda2DLaunchConfig config; + + if (xdim <= 0 || ydim <= 0) { + return config; + } + + const int kThreadsPerBlock = 256; + int block_cols = std::min(xdim, kThreadsPerBlock); + // ok to round down here and just do more loops in the kernel + int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + + const int physical_thread_count = + d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(); + + const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1); + + config.virtual_thread_count = dim3(xdim, ydim, 1); + config.thread_per_block = dim3(block_cols, block_rows, 1); + + int grid_x = std::min(DivUp(xdim, block_cols), max_blocks); + + config.block_count = dim3( + grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1); + return config; +} + +// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch. +// This variant takes the resource limits of func into account to maximize +// occupancy. 
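Before moving to the 2-D and 3-D variants, a concrete trace of the 1-D heuristic above; the device numbers are hypothetical, chosen only to show the arithmetic:

// Hypothetical device: 56 SMs, 2048 resident threads per SM, 1024 threads/block.
// work_element_count = 1 << 20  (1,048,576 elements)
//   physical_thread_count = min(56 * 2048, 1048576)      = 114688
//   thread_per_block      = min(1024, 1024)              = 1024
//   block_count           = min(DivUp(114688, 1024), 56) = min(112, 56) = 56
// The launch therefore uses 56 blocks x 1024 threads = 57,344 physical threads,
// and CUDA_1D_KERNEL_LOOP strides each of them over roughly 18 of the
// 1,048,576 virtual elements.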
+using Cuda3DLaunchConfig = Cuda2DLaunchConfig; + +template <typename DeviceFunc> +inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( + int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int block_size_limit) { + Cuda3DLaunchConfig config; + + if (xdim <= 0 || ydim <= 0 || zdim <= 0) { + return config; + } + + int dev; + cudaGetDevice(&dev); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, dev); + int xthreadlimit = deviceProp.maxThreadsDim[0]; + int ythreadlimit = deviceProp.maxThreadsDim[1]; + int zthreadlimit = deviceProp.maxThreadsDim[2]; + int xgridlimit = deviceProp.maxGridSize[0]; + int ygridlimit = deviceProp.maxGridSize[1]; + int zgridlimit = deviceProp.maxGridSize[2]; + + int block_count = 0; + int thread_per_block = 0; + cudaError_t err = cudaOccupancyMaxPotentialBlockSize( + &block_count, &thread_per_block, func, dynamic_shared_memory_size, + block_size_limit); + CHECK_EQ(err, cudaSuccess); + + auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); }; + + int threadsx = min3(xdim, thread_per_block, xthreadlimit); + int threadsy = + min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit); + int threadsz = + min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1), + zthreadlimit); + + int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit); + int blocksy = + min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit); + int blocksz = min3(DivUp(block_count, (blocksx * blocksy)), + DivUp(zdim, threadsz), zgridlimit); + + config.virtual_thread_count = dim3(xdim, ydim, zdim); + config.thread_per_block = dim3(threadsx, threadsy, threadsz); + config.block_count = dim3(blocksx, blocksy, blocksz); + return config; +} + +template <typename DeviceFunc> +inline Cuda2DLaunchConfig GetCuda2DLaunchConfig( + int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int block_size_limit) { + return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func, + dynamic_shared_memory_size, block_size_limit); +} + +// Returns a raw reference to the current cuda stream. Required by a +// number of kernel calls (for which StreamInterface* does not work), i.e. +// CUB and certain cublas primitives. +inline const cudaStream_t& GetCudaStream(OpKernelContext* context) { + const cudaStream_t* ptr = CHECK_NOTNULL( + reinterpret_cast<const cudaStream_t*>(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + return *ptr; +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_ |
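GetCudaStream exists because libraries such as CUB take a raw cudaStream_t rather than TensorFlow's stream wrappers. A hedged sketch of feeding the op's stream to cub::DeviceReduce::Sum; the function name, buffer handling, and include path are illustrative assumptions, not TensorFlow code:

#include "cub/device/device_reduce.cuh"  // Path depends on how CUB is vendored.

// Sums num_items floats from d_in into d_out[0] on the op's CUDA stream.
// d_temp/temp_bytes must have been sized by a prior CUB query call; in a real
// kernel both device buffers would come from the OpKernelContext allocator.
void SumOnOpStream(OpKernelContext* context, const float* d_in, float* d_out,
                   void* d_temp, size_t temp_bytes, int num_items) {
  const cudaStream_t& cu_stream = GetCudaStream(context);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, num_items, cu_stream);
}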