From f0c4c6c3f3a7e6df4dbd98385ec96a72638d5031 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 6 Jun 2017 00:55:45 -0700
Subject: In the CUDA path of depthwise_conv2d, add a fast NCHW backward filter
 convolution for images smaller than 16x16.

PiperOrigin-RevId: 158111294
---
 .../core/kernels/depthwise_conv_op_gpu.cu.cc | 254 ++++++++++++++++++---
 1 file changed, 224 insertions(+), 30 deletions(-)

diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 955756958b..e4d7c3d11e 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -34,8 +34,8 @@ namespace tensorflow {
 
 using Eigen::GpuDevice;
 
-// Returns whether depthwise convolution forward pass can be performed using the
-// faster ('Small') variant of the kernel.
+// Returns whether depthwise convolution forward or backward input pass can be
+// performed using the faster ('Small') variant of the kernel.
 EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
     const DepthwiseArgs args) {
   return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
@@ -47,6 +47,18 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
       (args.in_rows + 1) / 2 * args.in_cols;
 }
 
+// Returns whether depthwise convolution backward filter pass can be performed
+// using the faster ('Small') variant of the kernel.
+EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    const DepthwiseArgs args, const int block_rows) {
+  return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
+         args.in_cols <= 16 && args.in_rows == args.out_rows &&
+         args.in_cols == args.out_cols && args.pad_rows >= 0 &&
+         args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
+         args.pad_cols < args.filter_cols && block_rows <= args.in_rows &&
+         args.filter_rows * args.filter_cols <= args.in_cols * block_rows;
+}
+
 // A Cuda kernel to compute the depthwise convolution forward pass
 // in NHWC format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
     const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+  assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
   // Holds block plus halo and filter data for blockDim.x depths.
   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
   T* const shared_data = reinterpret_cast<T*>(shared_memory);
@@ -1336,6 +1350,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
   const int batch_blocks = (in_depth + block_slices - 1) / block_slices;
   const int in_blocks = batch_blocks * batches;
   const int tensor_offset = block_rows * in_row_size;
+  // The accumulator has a fixed number of pixels that can be reduced by one
+  // warp. Pixels beyond block_pixels/4 are never written.
   const int accum_pixels = 32;
   const int accum_increment = accum_pixels * block_slices;
   const int accum_size = filter_pixels * accum_increment;
@@ -1404,6 +1420,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
         UNROLL for (int c = 0; c < filter_cols; ++c) {
           const T* const tile_ptr = shared_offset + shared_data;
           T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+          // Sum up 4 block_pixels of the same depth and write to accumulator.
           val += CudaShuffleDown(val, 16);
           val += CudaShuffleDown(val, 8);
           if (!(thread_idx & 24) /* i.e. 'lane_idx < 8' */) {
@@ -1421,12 +1438,13 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
 
     const T* const accum_data = tile_size + shared_data;
     for (int i = thread_idx; i < accum_size; i += block_size) {
-      const int filter_idx = i / 32;
+      const int filter_idx = i / accum_pixels;
       const int filter_pix = filter_idx / block_slices;
       const int filter_depth = filter_idx % block_slices + start_depth;
       const int filter_offset = filter_pix * in_depth + filter_depth;
       if (filter_depth < in_depth) {
         T val = accum_data[i];
+        // Sum up the 32 pixels of the same depth from the accumulator.
         val += CudaShuffleDown(val, 16);
         val += CudaShuffleDown(val, 8);
         val += CudaShuffleDown(val, 4);
@@ -1540,46 +1558,222 @@ __global__ void __launch_bounds__(640, 2)
   }
 }
 
+// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
+// NCHW format, tailored for small images up to 16x16. Stride and depth
+// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
+// CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
+// Tiles of the input tensor are loaded into shared memory before performing the
+// convolution. Per iteration and filter element, each thread first performs
+// a partial convolution for two elements, one each in the lower and upper half
+// of a tile. The intermediate results of 4 consecutive columns are then
+// accumulated and written to shared memory. Finally, the values in shared
+// memory are warp-accumulated (in chunks of 32 elements) and summed up in
+// global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+__global__
+__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
+    const DepthwiseArgs args, const T* output, const T* input, T* filter) {
+  assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
+  // Holds block plus halo and filter data for blockDim.z depths.
+  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
+  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+
+  const int batches = args.batch;
+  const int in_rows = args.in_rows;
+  const int in_cols = args.in_cols;
+  const int in_depth = args.in_depth;
+  const int filter_rows =
+      kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
+  const int filter_cols =
+      kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
+  const int pad_rows = args.pad_rows;
+  const int pad_cols = args.pad_cols;
+
+  // Fixed blockDim.z, corresponding to Pascal's global load granularity of 32B.
+  const int block_rows = blockDim.y;
+  const int block_slices = 8;
+
+  // These values are the same for all threads and could
+  // be precomputed on the CPU.
+  const int block_pixels = in_cols * block_rows;
+  const int block_size = block_pixels * block_slices;
+  const int in_pixels = in_cols * in_rows;
+  const int in_increment = in_cols - 1;
+  const int filter_pixels = filter_rows * filter_cols;
+  const int tile_cols = in_cols + filter_cols - 1;
+  const int tile_rows = 2 * block_rows + filter_rows - 1;
+  const int tile_pixels = tile_cols * tile_rows;
+  const int tile_size = tile_pixels * block_slices;
+  const int tile_offset = block_rows * tile_cols;
+  const int pad_offset = pad_rows * tile_cols + pad_cols;
+  const int in_slices = in_depth * batches;
+  const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+  // The accumulator has a fixed number of pixels that can be reduced by one
+  // warp. Pixels beyond block_pixels/4 are never written.
+  const int accum_pixels = 32;
+  const int accum_increment = accum_pixels * block_slices;
+  const int accum_size = filter_pixels * accum_increment;
+
+  const int thread_col = threadIdx.x;
+  const int thread_row = threadIdx.y;
+  const int thread_depth = threadIdx.z;
+
+  // Position in block.
+  const int thread_pix = thread_row * in_cols + thread_col;
+  const int thread_idx = thread_depth * block_pixels + thread_pix;
+
+  // Initialize tile, in particular the padding and accumulator.
+  for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
+    shared_data[i] = T(0);
+  }
+  __syncthreads();
+
+  // Position in tensors.
+  const int tensor_idx = thread_depth * in_pixels + thread_pix;
+
+  // Position in (padded) shared memory.
+  const int data_pix = thread_row * tile_cols + thread_col;
+  const int data_idx = thread_depth * tile_pixels + data_pix;
+
+  // Position in shared memory, offset by pad_rows / pad_cols.
+  const int tile_idx = data_idx + pad_offset;
+
+  // Position in accumulator (1 per 4 threads, depth major).
+  const int accum_pix = thread_pix / 4;
+  const int accum_idx = thread_depth * accum_pixels + accum_pix;
+
+  const int max_slice = in_slices - thread_depth;
+  const int accum_offset = tile_size + accum_idx;
+  const bool skip_second = block_rows + thread_row >= in_rows;
+
+  for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
+    const int slice = b * block_slices;
+
+    const int inout_offset = slice * in_pixels + tensor_idx;
+    const bool slice_in_range = slice < max_slice;
+
+    if (slice_in_range) {
+      const T* const in_ptr = inout_offset + input;
+      T* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = ldg(in_ptr);
+      if (!skip_second) {
+        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    if (slice_in_range) {
+      const T* const out_ptr = inout_offset + output;
+      const T out1 = ldg(out_ptr);
+      const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr);
+      int shared_offset = data_idx;
+      T* accum_ptr = accum_offset + shared_data;
+      UNROLL for (int r = 0; r < filter_rows; ++r) {
+        UNROLL for (int c = 0; c < filter_cols; ++c) {
+          const T* const tile_ptr = shared_offset + shared_data;
+          T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+          // Sum up 4 block_pixels of the same depth and write to accumulator.
+          val += CudaShuffleDown(val, 2);
+          val += CudaShuffleDown(val, 1);
+          if (!(thread_idx & 3)) {
+            *accum_ptr = val;
+          }
+          ++shared_offset;
+          accum_ptr += accum_increment;
+        }
+        shared_offset += in_increment;
+      }
+    }
+
+    // Note: the condition to reach this is uniform across the entire block.
+    __syncthreads();
+
+    const T* const accum_data = tile_size + shared_data;
+    for (int i = thread_idx; i < accum_size; i += block_size) {
+      const int filter_idx = i / accum_pixels;
+      const int filter_pix = filter_idx / block_slices;
+      const int filter_depth = (slice + filter_idx % block_slices) % in_depth;
+      const int filter_offset = filter_pix * in_depth + filter_depth;
+      if (filter_depth < in_depth) {
+        T val = accum_data[i];
+        // Sum up 32 pixels of the same depth from the accumulator.
+        val += CudaShuffleDown(val, 16);
+        val += CudaShuffleDown(val, 8);
+        val += CudaShuffleDown(val, 4);
+        val += CudaShuffleDown(val, 2);
+        val += CudaShuffleDown(val, 1);
+        if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+          CudaAtomicAdd(filter_offset + filter, val);
+        }
+      }
+    }
+  }
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    const GpuDevice& d, const DepthwiseArgs args, int block_rows,
+    int shared_memory_size, const T* out_backprop, const T* input,
+    T* filter_backprop, TensorFormat data_format) {
+  const int block_slices = 8;
+  const int num_out_backprop =
+      args.batch * args.out_rows * args.out_cols * args.out_depth;
+  if (data_format == FORMAT_NHWC) {
+    dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
+                                                        kKnownFilterHeight>,
+        shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
+    DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
+                                                    kKnownFilterHeight>
+        <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
+            args, out_backprop, input, filter_backprop);
+  } else if (data_format == FORMAT_NCHW) {
+    dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
+                                                        kKnownFilterHeight>,
+        shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
+    DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
+                                                    kKnownFilterHeight>
+        <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
+            args, out_backprop, input, filter_backprop);
+  } else {
+    assert(false && "Incorrect data format");
+  }
+}
+
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
 bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
     const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop,
     const T* input, T* filter_backprop, TensorFormat data_format) {
-  if (data_format != FORMAT_NHWC || args.depth_multiplier != 1 ||
-      args.stride != 1 || args.in_rows > 16 || args.in_cols > 16 ||
-      args.in_rows != args.out_rows || args.in_cols != args.out_cols ||
-      args.pad_rows < 0 || args.pad_rows >= args.filter_rows ||
-      args.pad_cols < 0 || args.pad_cols >= args.filter_cols) {
-    return false;
-  }
-
+  // args.in_cols * block_rows (block_pixels) must be a multiple of 4.
   const int lookup_table[] = {0, 3, 1, 3};
   const int rows_mask = lookup_table[args.in_cols & 3];
   const int block_rows = (args.in_rows + 1) / 2 + rows_mask & ~rows_mask;
+  if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) {
+    return false;
+  }
+
+  const int block_slices = 8;
   const int tile_cols = args.in_cols + args.filter_cols - 1;
   const int tile_rows = block_rows * 2 + args.filter_rows - 1;
   const int tile_pixels = tile_rows * tile_cols;
   const int accum_size = args.filter_rows * args.filter_cols * 32;
-  dim3 block_dim = dim3(8, args.in_cols, block_rows);
   const int shared_memory_size =
-      block_dim.x * (tile_pixels + accum_size) * sizeof(T);
-
-  if (block_rows > args.in_rows ||
-      args.filter_rows * args.filter_cols > args.in_cols * block_rows ||
-      shared_memory_size > d.sharedMemPerBlock()) {
+      block_slices * (tile_pixels + accum_size) * sizeof(T);
+  if (shared_memory_size > d.sharedMemPerBlock()) {
     return false;
   }
 
-  const int num_out_backprop =
-      args.batch * args.out_rows * args.out_cols * args.out_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(
-      num_out_backprop, d,
-      DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
-                                                      kKnownFilterHeight>,
-      shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
-  DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
-                                                  kKnownFilterHeight>
-      <<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
-          args, out_backprop, input, filter_backprop);
+  LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+                                              kKnownFilterHeight>(
+      d, args, block_rows, shared_memory_size, out_backprop, input,
+      filter_backprop, data_format);
   return true;
 }
-- 
cgit v1.2.3
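
The new NCHW kernel accumulates the filter gradient in two stages: shuffles by 2 and 1 fold four pixels of the same depth into a shared-memory accumulator, and a full warp reduction (shuffles by 16/8/4/2/1) folds 32 accumulator entries before a single atomic add into global memory. The following is a minimal standalone sketch of that pattern, not code from the patch: the kernel name is hypothetical, it uses plain __shfl_down_sync/atomicAdd instead of TensorFlow's CudaShuffleDown/CudaAtomicAdd wrappers, and it assumes a single 32-thread warp reducing 128 inputs.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Hypothetical sketch of the two-stage warp-shuffle accumulation.
// Launch with exactly one 32-thread warp.
__global__ void TwoStageWarpReduceSketch(const float* in, float* out) {
  __shared__ float accum[32];
  const int lane = threadIdx.x & 31;

  // Stage 1: after shuffling down by 2 and then 1, every 4th lane holds the
  // sum of its group of 4 values and writes one accumulator entry per pass.
  for (int pass = 0; pass < 4; ++pass) {
    float val = in[pass * 32 + lane];
    val += __shfl_down_sync(0xffffffffu, val, 2);
    val += __shfl_down_sync(0xffffffffu, val, 1);
    if (!(lane & 3)) accum[pass * 8 + lane / 4] = val;
  }
  __syncwarp();

  // Stage 2: a full warp-tree reduction over the 32 accumulator entries;
  // lane 0 ends up with the total and adds it to global memory atomically.
  float val = accum[lane];
  val += __shfl_down_sync(0xffffffffu, val, 16);
  val += __shfl_down_sync(0xffffffffu, val, 8);
  val += __shfl_down_sync(0xffffffffu, val, 4);
  val += __shfl_down_sync(0xffffffffu, val, 2);
  val += __shfl_down_sync(0xffffffffu, val, 1);
  if (lane == 0) atomicAdd(out, val);
}

int main() {
  std::vector<float> host_in(128, 1.0f);
  float *dev_in = nullptr, *dev_out = nullptr;
  cudaMalloc(&dev_in, host_in.size() * sizeof(float));
  cudaMalloc(&dev_out, sizeof(float));
  cudaMemcpy(dev_in, host_in.data(), host_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemset(dev_out, 0, sizeof(float));
  TwoStageWarpReduceSketch<<<1, 32>>>(dev_in, dev_out);
  float result = 0.0f;
  cudaMemcpy(&result, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum = %.0f (expected 128)\n", result);
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}

In the patch itself the two stages are separated by a __syncthreads() because the accumulator is shared by several warps of the same block, and stage 2 runs once per filter element rather than once per warp.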
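The block_rows rounding in TryLaunchDepthwiseConv2dBackpropFilterGPUSmall exists so that args.in_cols * block_rows (block_pixels) is a multiple of 4, since the NCHW kernel assumes four consecutive threads of the same depth share one accumulator slot (accum_pix = thread_pix / 4). A small host-side check of that property, assuming images of at most 16x16 as the kernels require (again not part of the patch, just an illustration):

#include <cassert>
#include <cstdio>

int main() {
  // Same table and rounding expression as in the TryLaunch function above.
  const int lookup_table[] = {0, 3, 1, 3};
  for (int in_cols = 1; in_cols <= 16; ++in_cols) {
    for (int in_rows = 1; in_rows <= 16; ++in_rows) {
      const int rows_mask = lookup_table[in_cols & 3];
      // '+' binds tighter than '&', so this rounds (in_rows + 1) / 2 up to a
      // multiple of rows_mask + 1 (1, 2 or 4, depending on in_cols % 4).
      const int block_rows = (in_rows + 1) / 2 + rows_mask & ~rows_mask;
      assert(in_cols * block_rows % 4 == 0);
      assert(block_rows >= (in_rows + 1) / 2);
    }
  }
  std::printf("block_pixels is a multiple of 4 for all sizes up to 16x16\n");
  return 0;
}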