Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 30 ++++++++++++++++++++++--------
1 file changed, 22 insertions(+), 8 deletions(-)
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index fcfcd188d2..ecfe51d599 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "external/cub_archive/cub/util_ptx.cuh"
#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
@@ -1015,6 +1016,21 @@ __global__ void __launch_bounds__(640, 2)
 }
 }
 
+// Device function to compute sub-warp sum reduction for a power-of-two group of
+// neighboring threads.
+template<int kWidth, typename T>
+__device__ __forceinline__ T WarpSumReduce(T val) {
+  // support only power-of-two widths.
+  assert(__popc(kWidth) == 1);
+  int sub_warp = cub::LaneId() / kWidth;
+  int zeros = sub_warp * kWidth;
+  unsigned mask = ((1UL << kWidth) - 1) << zeros;
+  for (int delta = kWidth / 2; delta > 0; delta /= 2) {
+    val += CudaShuffleXor(mask, val, delta);
+  }
+  return val;
+}
+
 // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
 // NHWC format, tailored for small images up to 32x32. Stride and depth
 // multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
@@ -1127,6 +1143,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
 
+  unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
   if (depth_in_range) {
     const T* const out_ptr = inout_offset + output;
 
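Taking the ballot before the branch is the crux of the CUDA 9 port: the *_sync shuffle intrinsics must be passed the exact set of participating lanes, and every lane in that set has to reach the call. CudaBallot(CUDA_WARP_ALL, pred), corresponding to __ballot_sync(0xffffffff, pred), is therefore evaluated at a point every thread of the block reaches (right after the __syncthreads), capturing precisely the lanes that will enter the if. A small sketch of the pattern with hypothetical names, using the raw intrinsics:

__global__ void BallotBeforeBranch(const int* flags, unsigned* warp_masks) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // assumes blockDim.x % 32 == 0
  bool pred = flags[i] != 0;
  // Converged point: all 32 lanes of the warp execute the ballot together.
  unsigned active = __ballot_sync(0xffffffffu, pred);
  if (pred) {
    // Inside the divergent branch, 'active' names exactly the live lanes and
    // is the right participation mask for any __shfl_*_sync placed here.
    if ((threadIdx.x & 31) == __ffs(active) - 1) {  // lowest live lane reports
      warp_masks[i / 32] = active;
    }
  }
}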
@@ -1140,7 +1157,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
       T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
       // Warp-accumulate pixels of the same depth and write to accumulator.
       for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
-        val += CudaShuffleDown(val, delta);
+        val += CudaShuffleDown(active_threads, val, delta);
       }
       if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
         *accum_ptr = val;
@@ -1164,9 +1181,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
       if (filter_depth < in_depth) {
         T val = accum_data[i];
         // Warp-accumulate the pixels of the same depth from the accumulator.
-        for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
-          val += CudaShuffleDown(val, delta);
-        }
+        val = WarpSumReduce<kAccumPixels>(val);
         if (!(thread_idx & kAccumPixels - 1)) {
           CudaAtomicAdd(filter_offset + filter, val);
         }
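The deleted loop was the pre-CUDA-9 shuffle-down form without a member mask; WarpSumReduce<kAccumPixels> produces the same kAccumPixels-wide group sums but builds the explicit per-group mask the masked intrinsics require. Since the butterfly leaves the sum in every lane, the existing guard simply picks the first lane of each group for the atomic update. A hypothetical call site with kAccumPixels = 8:

// Each 8-lane group reduces its accumulator values; lane 0 of the group
// publishes the sum with an atomic add (float atomicAdd, sm_20+).
__global__ void GroupSumDemo(const float* accum, float* filter) {
  float val = WarpSumReduce<8>(accum[threadIdx.x]);
  if (!(threadIdx.x & 7)) {  // first lane of each 8-lane group
    atomicAdd(&filter[threadIdx.x / 8], val);
  }
}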
@@ -1382,6 +1397,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
   // Note: the condition to reach this is uniform across the entire block.
   __syncthreads();
 
+  unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
   if (slice_in_range) {
     const T* const out_ptr = inout_offset + output;
 
@@ -1395,7 +1411,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
       T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
       // Warp-accumulate pixels of the same depth and write to accumulator.
       for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
-        val += CudaShuffleDown(val, delta);
+        val += CudaShuffleDown(active_threads, val, delta);
       }
       if (!(thread_idx & 32 / kBlockSlices - 1)) {
         *accum_ptr = val;
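The NCHW variant groups lanes the other way around: lanes of the same depth slice are contiguous, so delta starts at 16 / kBlockSlices and the loop folds 32 / kBlockSlices neighboring lanes together; the guard (thread_idx & 32 / kBlockSlices - 1) then selects the first lane of each contiguous group. In isolation, under the same full-warp and wrapper-substitution assumptions as the NHWC sketch above:

// Contiguous grouping: groups of 32/kSlices neighboring lanes are summed;
// lanes 0, 32/kSlices, 2*(32/kSlices), ... end up with their group's total.
template <int kSlices>
__device__ float ContiguousWarpSum(float val) {
  for (int delta = 16 / kSlices; delta > 0; delta /= 2) {
    val += __shfl_down_sync(0xffffffffu, val, delta);
  }
  return val;
}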
@@ -1419,9 +1435,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
       if (filter_depth < in_depth) {
         T val = accum_data[i];
         // Warp-accumulate pixels of the same depth from the accumulator.
-        for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
-          val += CudaShuffleDown(val, delta);
-        }
+        val = WarpSumReduce<kAccumPixels>(val);
         if (!(thread_idx & kAccumPixels - 1)) {
           CudaAtomicAdd(filter_offset + filter, val);
         }