author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-06 00:55:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-06-06 00:59:17 -0700
commit    f0c4c6c3f3a7e6df4dbd98385ec96a72638d5031 (patch)
tree      f8dc3369b9cb14246547d5a6e27977ce78c14846
parent    232e9d86d81ce4476fc9ea674b9078dd0794f92b (diff)
In the CUDA path of depthwise_conv2d, add a fast NCHW backward filter convolution for images smaller than 16x16.
PiperOrigin-RevId: 158111294
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc  254
1 file changed, 224 insertions, 30 deletions
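For reference, the two layouts involved: NHWC keeps the channel innermost, while NCHW stores one contiguous H*W plane per channel. A minimal sketch of the corresponding linear offsets (illustrative helpers, not part of the patch):

// Illustrative linear-offset helpers for the two tensor layouts (not part of the patch).
inline int OffsetNHWC(int b, int h, int w, int c, int H, int W, int C) {
  return ((b * H + h) * W + w) * C + c;  // channel varies fastest
}
inline int OffsetNCHW(int b, int c, int h, int w, int C, int H, int W) {
  return ((b * C + c) * H + h) * W + w;  // one contiguous H*W plane per channel
}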
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 955756958b..e4d7c3d11e 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -34,8 +34,8 @@ namespace tensorflow {
using Eigen::GpuDevice;
-// Returns whether depthwise convolution forward pass can be performed using the
-// faster ('Small') variant of the kernel.
+// Returns whether depthwise convolution forward or backward input pass can be
+// performed using the faster ('Small') variant of the kernel.
EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
const DepthwiseArgs args) {
return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
@@ -47,6 +47,18 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
(args.in_rows + 1) / 2 * args.in_cols;
}
+// Returns whether depthwise convolution backward filter pass can be performed
+// using the faster ('Small') variant of the kernel.
+EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ const DepthwiseArgs args, const int block_rows) {
+ return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
+ args.in_cols <= 16 && args.in_rows == args.out_rows &&
+ args.in_cols == args.out_cols && args.pad_rows >= 0 &&
+ args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
+ args.pad_cols < args.filter_cols && block_rows <= args.in_rows &&
+ args.filter_rows * args.filter_cols <= args.in_cols * block_rows;
+}
+
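As a concrete instance of the predicate above (hypothetical numbers, not taken from this patch): a 16x16 input with a 3x3 filter, stride 1, depth multiplier 1 and 'SAME' padding gives pad_rows = pad_cols = 1 and, with block_rows = 8, passes every condition:

// Illustrative compile-time check of the conditions above for a hypothetical
// 16x16 input, 3x3 filter, stride-1, 'SAME'-padded case (not part of the patch).
constexpr int in_rows = 16, in_cols = 16;
constexpr int filter_rows = 3, filter_cols = 3;
constexpr int pad_rows = 1, pad_cols = 1;      // 'SAME' padding for a 3x3 filter
constexpr int block_rows = (in_rows + 1) / 2;  // 8, the block_rows computed by the host code below
static_assert(in_rows <= 16 && in_cols <= 16, "small-image limit");
static_assert(block_rows <= in_rows, "block fits inside the image");
static_assert(filter_rows * filter_cols <= in_cols * block_rows,
              "filter fits into one block of pixels");
static_assert(pad_rows >= 0 && pad_rows < filter_rows &&
              pad_cols >= 0 && pad_cols < filter_cols, "'SAME' padding");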
// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -1288,8 +1300,9 @@ __global__ void __launch_bounds__(640, 2)
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
-// NCHW format, tailored for small images up to 16x16. Stride and depth
-// multiplier must be 1. Padding must be 'SAME'.
+// NHWC format, tailored for small images up to 16x16. Stride and depth
+// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
+// CanLaunchDepthwiseConv2dBackpropFilterGPUSmall returns true.
// Tiles of the input tensor are loaded into shared memory before performing the
// convolution. Per iteration and filter element, each thread first performs
// a partial convolution for two elements, one each in the lower and upper half
@@ -1301,6 +1314,7 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+ assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
// Holds block plus halo and filter data for blockDim.x depths.
extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
T* const shared_data = reinterpret_cast<T*>(shared_memory);
@@ -1336,6 +1350,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int batch_blocks = (in_depth + block_slices - 1) / block_slices;
const int in_blocks = batch_blocks * batches;
const int tensor_offset = block_rows * in_row_size;
+ // The accumulator has a fixed number of pixels that can be reduced by one
+ // warp. Pixels beyond block_pixels/4 are never written.
const int accum_pixels = 32;
const int accum_increment = accum_pixels * block_slices;
const int accum_size = filter_pixels * accum_increment;
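For concreteness, the accumulator sizing with the block_slices = 8 used here and a hypothetical 3x3 filter (illustrative numbers, not from the patch):

constexpr int accum_pixels = 32, block_slices = 8;
constexpr int filter_pixels = 3 * 3;                           // hypothetical 3x3 filter
constexpr int accum_increment = accum_pixels * block_slices;   // 256
constexpr int accum_size = filter_pixels * accum_increment;    // 2304 accumulator elements
static_assert(accum_size == 2304, "illustrative sizing only");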
@@ -1404,6 +1420,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
UNROLL for (int c = 0; c < filter_cols; ++c) {
const T* const tile_ptr = shared_offset + shared_data;
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+ // Sum up 4 block_pixels of the same depth and write to accumulator.
val += CudaShuffleDown(val, 16);
val += CudaShuffleDown(val, 8);
if (!(thread_idx & 24) /* i.e. 'lane_idx < 8' */) {
@@ -1421,12 +1438,13 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
- const int filter_idx = i / 32;
+ const int filter_idx = i / accum_pixels;
const int filter_pix = filter_idx / block_slices;
const int filter_depth = filter_idx % block_slices + start_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
+ // Sum up the 32 pixels of the same depth from the accumulator.
val += CudaShuffleDown(val, 16);
val += CudaShuffleDown(val, 8);
val += CudaShuffleDown(val, 4);
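Both reductions above are the standard warp shuffle-down pattern; CudaShuffleDown and CudaAtomicAdd are TensorFlow's thin wrappers over the corresponding CUDA primitives. A minimal standalone sketch of the full pattern, written against the raw intrinsics in their modern _sync form (illustrative only, not part of the patch):

// Sums 32 values per warp into lane 0 and accumulates the result into *out.
__global__ void WarpSumDemo(const float* in, float* out) {
  float val = in[threadIdx.x];  // one value per lane
  val += __shfl_down_sync(0xffffffff, val, 16);
  val += __shfl_down_sync(0xffffffff, val, 8);
  val += __shfl_down_sync(0xffffffff, val, 4);
  val += __shfl_down_sync(0xffffffff, val, 2);
  val += __shfl_down_sync(0xffffffff, val, 1);
  if ((threadIdx.x & 31) == 0) {
    atomicAdd(out, val);  // lane 0 of each warp contributes its partial sum
  }
}

Stopping after the first two steps (deltas 16 and 8), as in the partial reduction further up, leaves lanes 0-7 holding the sums of lanes {i, i+8, i+16, i+24}, i.e. the four pixels that share a depth when block_slices = 8 depths are interleaved across the warp.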
@@ -1540,46 +1558,222 @@ __global__ void __launch_bounds__(640, 2)
}
}
+// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
+// NCHW format, tailored for small images up to 16x16. Stride and depth
+// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
+// CanLaunchDepthwiseConv2dBackpropFilterGPUSmall returns true.
+// Tiles of the input tensor are loaded into shared memory before performing the
+// convolution. Per iteration and filter element, each thread first performs
+// a partial convolution for two elements, one each in the lower and upper half
+// of a tile. The intermediate result of 4 consecutive columns are then
+// accumulated and written to shared memory. Finally, the values in shared
+// memory are warp-accumulated (in chunks of 32 elements) and summed up in
+// global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+__global__
+__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
+ const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+ assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.y));
+ // Holds block plus halo and filter data for blockDim.z depths.
+ extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
+ T* const shared_data = reinterpret_cast<T*>(shared_memory);
+
+ const int batches = args.batch;
+ const int in_rows = args.in_rows;
+ const int in_cols = args.in_cols;
+ const int in_depth = args.in_depth;
+ const int filter_rows =
+ kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
+ const int filter_cols =
+ kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
+ const int pad_rows = args.pad_rows;
+ const int pad_cols = args.pad_cols;
+
+ // Fixed blockDim.z, corresponding to Pascal's global load granularity of 32B.
+ const int block_rows = blockDim.y;
+ const int block_slices = 8;
+
+ // These values are the same for all threads and could
+ // be precomputed on the CPU.
+ const int block_pixels = in_cols * block_rows;
+ const int block_size = block_pixels * block_slices;
+ const int in_pixels = in_cols * in_rows;
+ const int in_increment = in_cols - 1;
+ const int filter_pixels = filter_rows * filter_cols;
+ const int tile_cols = in_cols + filter_cols - 1;
+ const int tile_rows = 2 * block_rows + filter_rows - 1;
+ const int tile_pixels = tile_cols * tile_rows;
+ const int tile_size = tile_pixels * block_slices;
+ const int tile_offset = block_rows * tile_cols;
+ const int pad_offset = pad_rows * tile_cols + pad_cols;
+ const int in_slices = in_depth * batches;
+ const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+ // The accumulator has a fixed number of pixels that can be reduced by one
+ // warp. Pixels beyond block_pixels/4 are never written.
+ const int accum_pixels = 32;
+ const int accum_increment = accum_pixels * block_slices;
+ const int accum_size = filter_pixels * accum_increment;
+
+ const int thread_col = threadIdx.x;
+ const int thread_row = threadIdx.y;
+ const int thread_depth = threadIdx.z;
+
+ // Position in block.
+ const int thread_pix = thread_row * in_cols + thread_col;
+ const int thread_idx = thread_depth * block_pixels + thread_pix;
+
+ // Initialize tile, in particular the padding and accumulator.
+ for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
+ shared_data[i] = T(0);
+ }
+ __syncthreads();
+
+ // Position in tensors.
+ const int tensor_idx = thread_depth * in_pixels + thread_pix;
+
+ // Position in (padded) shared memory.
+ const int data_pix = thread_row * tile_cols + thread_col;
+ const int data_idx = thread_depth * tile_pixels + data_pix;
+
+ // Position in shared memory, offset by pad_rows / pad_cols.
+ const int tile_idx = data_idx + pad_offset;
+
+ // Position in accumulator (1 per 4 threads, depth major).
+ const int accum_pix = thread_pix / 4;
+ const int accum_idx = thread_depth * accum_pixels + accum_pix;
+
+ const int max_slice = in_slices - thread_depth;
+ const int accum_offset = tile_size + accum_idx;
+ const bool skip_second = block_rows + thread_row >= in_rows;
+
+ for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
+ const int slice = b * block_slices;
+
+ const int inout_offset = slice * in_pixels + tensor_idx;
+ const bool slice_in_range = slice < max_slice;
+
+ if (slice_in_range) {
+ const T* const in_ptr = inout_offset + input;
+ T* const tile_ptr = tile_idx + shared_data;
+ tile_ptr[0] = ldg(in_ptr);
+ if (!skip_second) {
+ tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+ }
+ }
+
+ // Note: the condition to reach this is uniform across the entire block.
+ __syncthreads();
+
+ if (slice_in_range) {
+ const T* const out_ptr = inout_offset + output;
+ const T out1 = ldg(out_ptr);
+ const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr);
+ int shared_offset = data_idx;
+ T* accum_ptr = accum_offset + shared_data;
+ UNROLL for (int r = 0; r < filter_rows; ++r) {
+ UNROLL for (int c = 0; c < filter_cols; ++c) {
+ const T* const tile_ptr = shared_offset + shared_data;
+ T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+ // Sum up 4 block_pixels of the same depth and write to accumulator.
+ val += CudaShuffleDown(val, 2);
+ val += CudaShuffleDown(val, 1);
+ if (!(thread_idx & 3)) {
+ *accum_ptr = val;
+ }
+ ++shared_offset;
+ accum_ptr += accum_increment;
+ }
+ shared_offset += in_increment;
+ }
+ }
+
+ // Note: the condition to reach this is uniform across the entire block.
+ __syncthreads();
+
+ const T* const accum_data = tile_size + shared_data;
+ for (int i = thread_idx; i < accum_size; i += block_size) {
+ const int filter_idx = i / accum_pixels;
+ const int filter_pix = filter_idx / block_slices;
+ const int filter_depth = (slice + filter_idx % block_slices) % in_depth;
+ const int filter_offset = filter_pix * in_depth + filter_depth;
+ if (filter_depth < in_depth) {
+ T val = accum_data[i];
+ // Sum up 32 pixels of the same depth from the accumulator.
+ val += CudaShuffleDown(val, 16);
+ val += CudaShuffleDown(val, 8);
+ val += CudaShuffleDown(val, 4);
+ val += CudaShuffleDown(val, 2);
+ val += CudaShuffleDown(val, 1);
+ if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+ CudaAtomicAdd(filter_offset + filter, val);
+ }
+ }
+ }
+ }
+}
+
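The pad_offset arithmetic above places the input pixel at (row, col) at position (row + pad_rows, col + pad_cols) inside the zero-initialized padded tile, so the convolution can read the halo without bounds checks. A small compile-time check of that identity (hypothetical sizes, not from the patch):

constexpr int in_cols = 16, filter_cols = 3;
constexpr int pad_rows = 1, pad_cols = 1;
constexpr int tile_cols = in_cols + filter_cols - 1;   // 18
constexpr int thread_row = 5, thread_col = 7;          // arbitrary in-block position
constexpr int data_pix = thread_row * tile_cols + thread_col;
constexpr int pad_offset = pad_rows * tile_cols + pad_cols;
static_assert(data_pix + pad_offset ==
                  (thread_row + pad_rows) * tile_cols + (thread_col + pad_cols),
              "pad_offset shifts writes into the interior of the padded tile");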
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ const GpuDevice& d, const DepthwiseArgs args, int block_rows,
+ int shared_memory_size, const T* out_backprop, const T* input,
+ T* filter_backprop, TensorFormat data_format) {
+ const int block_slices = 8;
+ const int num_out_backprop =
+ args.batch * args.out_rows * args.out_cols * args.out_depth;
+ if (data_format == FORMAT_NHWC) {
+ dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
+ CudaLaunchConfig config = GetCudaLaunchConfig(
+ num_out_backprop, d,
+ DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>,
+ shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
+ DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>
+ <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
+ args, out_backprop, input, filter_backprop);
+ } else if (data_format == FORMAT_NCHW) {
+ dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+ CudaLaunchConfig config = GetCudaLaunchConfig(
+ num_out_backprop, d,
+ DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>,
+ shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
+ DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>
+ <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
+ args, out_backprop, input, filter_backprop);
+ } else {
+ assert(false && "Incorrect data format");
+ }
+}
+
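A quick check of the launch geometry above for the largest supported case (illustrative, not from the patch): 16-column rows, block_rows = 8 and block_slices = 8 give exactly the 1024 threads per block permitted by __launch_bounds__(1024, 2) on both kernels.

constexpr int max_in_cols = 16, max_block_rows = 8, block_slices = 8;
constexpr int threads_per_block = max_in_cols * max_block_rows * block_slices;
static_assert(threads_per_block == 1024, "matches __launch_bounds__(1024, 2)");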
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop,
const T* input, T* filter_backprop, TensorFormat data_format) {
- if (data_format != FORMAT_NHWC || args.depth_multiplier != 1 ||
- args.stride != 1 || args.in_rows > 16 || args.in_cols > 16 ||
- args.in_rows != args.out_rows || args.in_cols != args.out_cols ||
- args.pad_rows < 0 || args.pad_rows >= args.filter_rows ||
- args.pad_cols < 0 || args.pad_cols >= args.filter_cols) {
- return false;
- }
-
+ // args.in_cols * block_rows (block_pixels) must be a multiple of 4.
const int lookup_table[] = {0, 3, 1, 3};
const int rows_mask = lookup_table[args.in_cols & 3];
const int block_rows = (args.in_rows + 1) / 2 + rows_mask & ~rows_mask;
+ if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) {
+ return false;
+ }
+
+ const int block_slices = 8;
const int tile_cols = args.in_cols + args.filter_cols - 1;
const int tile_rows = block_rows * 2 + args.filter_rows - 1;
const int tile_pixels = tile_rows * tile_cols;
const int accum_size = args.filter_rows * args.filter_cols * 32;
- dim3 block_dim = dim3(8, args.in_cols, block_rows);
const int shared_memory_size =
- block_dim.x * (tile_pixels + accum_size) * sizeof(T);
-
- if (block_rows > args.in_rows ||
- args.filter_rows * args.filter_cols > args.in_cols * block_rows ||
- shared_memory_size > d.sharedMemPerBlock()) {
+ block_slices * (tile_pixels + accum_size) * sizeof(T);
+ if (shared_memory_size > d.sharedMemPerBlock()) {
return false;
}
- const int num_out_backprop =
- args.batch * args.out_rows * args.out_cols * args.out_depth;
- CudaLaunchConfig config = GetCudaLaunchConfig(
- num_out_backprop, d,
- DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>,
- shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
- DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>
- <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
- args, out_backprop, input, filter_backprop);
+ LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>(
+ d, args, block_rows, shared_memory_size, out_backprop, input,
+ filter_backprop, data_format);
return true;
}
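A worked instance of the host-side sizing above, for a hypothetical 16x16 image, 3x3 filter and float data; the 48 KB figure is a typical per-block shared-memory limit, and d.sharedMemPerBlock() remains the authoritative check:

// Mirrors the arithmetic in TryLaunchDepthwiseConv2dBackpropFilterGPUSmall
// (illustrative numbers, not from the patch).
constexpr int in_rows = 16, in_cols = 16, filter_rows = 3, filter_cols = 3;
constexpr int lookup_table[] = {0, 3, 1, 3};
constexpr int rows_mask = lookup_table[in_cols & 3];                    // 0
constexpr int block_rows = (in_rows + 1) / 2 + rows_mask & ~rows_mask;  // 8
constexpr int block_slices = 8;
constexpr int tile_cols = in_cols + filter_cols - 1;                    // 18
constexpr int tile_rows = block_rows * 2 + filter_rows - 1;             // 18
constexpr int tile_pixels = tile_rows * tile_cols;                      // 324
constexpr int accum_size = filter_rows * filter_cols * 32;              // 288
constexpr int shared_memory_size =
    block_slices * (tile_pixels + accum_size) * static_cast<int>(sizeof(float));
static_assert(in_cols * block_rows % 4 == 0, "block_pixels is a multiple of 4");
static_assert(shared_memory_size == 19584, "well under a 48 KB shared-memory limit");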