author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-15 22:06:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-15 22:10:03 -0700
commit    7fffdb236ecaf7a2f50f3363e947b19e2a5a327a (patch)
tree      50bdb514cc88beb9745db4ae4a8a9b3f8b287da5
parent    8bb8099697c3104ef00ae90027f1d17e3b6e992b (diff)
Automated g4 rollback of changelist 158990733
PiperOrigin-RevId: 159194246
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc       | 541
-rw-r--r--  tensorflow/python/kernel_tests/depthwise_conv_op_test.py  |  19
2 files changed, 342 insertions, 218 deletions
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 319dbb68e6..038c596207 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -38,8 +38,8 @@ using Eigen::GpuDevice;
// performed using the faster ('Small') variant of the kernel.
EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
const DepthwiseArgs args) {
- return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
- args.in_cols <= 16 && args.in_rows == args.out_rows &&
+ return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 &&
+ args.in_cols <= 32 && args.in_rows == args.out_rows &&
args.in_cols == args.out_cols && args.pad_rows >= 0 &&
args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
args.pad_cols < args.filter_cols &&
@@ -51,8 +51,8 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
// using the faster ('Small') variant of the kernel.
EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(
const DepthwiseArgs args, const int block_rows) {
- return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 16 &&
- args.in_cols <= 16 && args.in_rows == args.out_rows &&
+ return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 &&
+ args.in_cols <= 32 && args.in_rows == args.out_rows &&
args.in_cols == args.out_cols && args.pad_rows >= 0 &&
args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
args.pad_cols < args.filter_cols && block_rows <= args.in_rows &&
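Both hunks relax the same bound: images up to 32x32 (rather than 16x16) may now take the fast path, because the launchers further down shrink the number of depth slices per block to keep a block within 1024 threads. A minimal, hypothetical driver showing a shape the relaxed predicate accepts (assuming the remaining, elided conditions of the predicate hold for it; illustrative code, not part of the patch):

    DepthwiseArgs args = {};
    args.depth_multiplier = 1;
    args.stride = 1;
    args.in_rows = args.out_rows = 32;  // rejected when the bound was 16
    args.in_cols = args.out_cols = 32;
    args.filter_rows = args.filter_cols = 3;
    args.pad_rows = args.pad_cols = 1;  // 'SAME' padding for a 3x3 filter
    assert(CanLaunchDepthwiseConv2dGPUSmall(args));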
@@ -142,14 +142,14 @@ __global__ void __launch_bounds__(1024, 2)
}
// CUDA kernel to compute the depthwise convolution forward pass in NHWC format,
-// tailored for small images up to 16x16. Stride and depth multiplier must be 1.
+// tailored for small images up to 32x32. Stride and depth multiplier must be 1.
// Padding must be 'SAME', which allows reusing the index computation. Only
// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
// Tiles of the input and filter tensors are loaded into shared memory before
// performing the convolution. Each thread handles two elements per iteration,
// one each in the lower and upper half of a tile.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
@@ -168,25 +168,23 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
- // Fixed blockDim.x, corresponding to Pascal's global load granularity of 32B.
- const int block_slices = 8;
const int block_rows = blockDim.z;
// These values are the same for all threads and could
// be precomputed on the CPU.
- const int block_size = block_rows * in_cols * block_slices;
+ const int block_size = block_rows * in_cols * kBlockSlices;
const int in_row_size = in_cols * in_depth;
const int in_size = in_rows * in_row_size;
- const int in_increment = (in_cols - 1) * block_slices;
+ const int in_increment = (in_cols - 1) * kBlockSlices;
const int filter_pixels = filter_rows * filter_cols;
const int tile_cols = in_cols + filter_cols - 1;
const int even_rows = kKnownEvenRows || (1 & ~in_rows);
const int tile_rows = in_rows + filter_rows - even_rows;
- const int tile_row_size = tile_cols * block_slices;
+ const int tile_row_size = tile_cols * kBlockSlices;
const int tile_size = tile_rows * tile_row_size;
const int tile_offset = block_rows * tile_row_size;
const int pad_offset = pad_rows * tile_cols + pad_cols;
- const int batch_blocks = (in_depth + block_slices - 1) / block_slices;
+ const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices;
const int in_blocks = batch_blocks * batches;
const int tensor_offset =
kKnownEvenRows ? in_size / 2 : block_rows * in_row_size;
@@ -197,7 +195,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
// Position in block.
const int thread_pix = thread_row * in_cols + thread_col;
- const int thread_idx = thread_pix * block_slices + thread_depth;
+ const int thread_idx = thread_pix * kBlockSlices + thread_depth;
// Initialize tile, in particular the padding.
for (int i = thread_idx; i < tile_size; i += block_size) {
@@ -210,11 +208,11 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
// Position in (padded) shared memory.
const int data_pix = thread_row * tile_cols + thread_col;
- const int data_idx = data_pix * block_slices + thread_depth;
+ const int data_idx = data_pix * kBlockSlices + thread_depth;
// Position in shared memory, offset by pad_rows / pad_cols.
const int tile_pix = data_pix + pad_offset;
- const int tile_idx = tile_pix * block_slices + thread_depth;
+ const int tile_idx = tile_pix * kBlockSlices + thread_depth;
const int max_depth = in_depth - thread_depth;
const int filter_write_offset =
@@ -227,7 +225,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const int batch = b / batch_blocks;
const int stack = b - batch * batch_blocks;
- const int start_depth = stack * block_slices;
+ const int start_depth = stack * kBlockSlices;
const int filter_offset = tensor_idx + start_depth;
const int inout_offset = batch * in_size + filter_offset;
const bool depth_in_range = start_depth < max_depth;
@@ -259,8 +257,8 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const T* const tile_ptr = shared_offset + shared_data;
sum1 += filter_value * tile_ptr[0];
sum2 += filter_value * tile_ptr[tile_offset];
- shared_offset += block_slices;
- filter_ptr += block_slices;
+ shared_offset += kBlockSlices;
+ filter_ptr += kBlockSlices;
}
shared_offset += in_increment;
}
@@ -404,14 +402,14 @@ __global__ void __launch_bounds__(1024, 2)
}
// CUDA kernel to compute the depthwise convolution forward pass in NCHW format,
-// tailored for small images up to 16x16. Stride and depth multiplier must be 1.
+// tailored for small images up to 32x32. Stride and depth multiplier must be 1.
// Padding must be 'SAME', which allows reusing the index computation. Only
// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
// Tiles of the input and filter tensors are loaded into shared memory before
// performing the convolution. Each thread handles two elements per iteration,
// one each in the lower and upper half of a tile.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
@@ -432,12 +430,11 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
// Fixed blockDim.z, tailored for maximum grid size for images of size 16x16.
const int block_rows = blockDim.y;
- const int block_slices = 8;
// These values are the same for all threads and could
// be precomputed on the CPU.
const int block_pixels = in_cols * block_rows;
- const int block_size = block_pixels * block_slices;
+ const int block_size = block_pixels * kBlockSlices;
const int in_pixels = in_cols * in_rows;
const int in_increment = in_cols - 1;
const int filter_pixels = filter_rows * filter_cols;
@@ -445,11 +442,11 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
const int even_rows = kKnownEvenRows || (1 & ~in_rows);
const int tile_rows = in_rows + filter_rows - even_rows;
const int tile_pixels = tile_cols * tile_rows;
- const int tile_size = tile_pixels * block_slices;
+ const int tile_size = tile_pixels * kBlockSlices;
const int tile_offset = block_rows * tile_cols;
const int pad_offset = pad_rows * tile_cols + pad_cols;
const int in_slices = in_depth * batches;
- const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+ const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices;
const int thread_col = threadIdx.x;
const int thread_row = threadIdx.y;
@@ -476,8 +473,8 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
const int tile_idx = data_idx + pad_offset;
// Filter is always in HWCK format, irrespective of the input/output format.
- const int filter_pix = thread_idx / block_slices;
- const int filter_depth = thread_idx % block_slices;
+ const int filter_pix = thread_idx / kBlockSlices;
+ const int filter_depth = thread_idx % kBlockSlices;
const int filter_idx = filter_pix * in_depth;
const int max_slice = in_slices - thread_depth;
@@ -488,7 +485,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
!kKnownEvenRows && thread_row + (in_rows & 1) == block_rows;
for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
- const int slice = b * block_slices;
+ const int slice = b * kBlockSlices;
const int inout_offset = slice * in_pixels + tensor_idx;
const bool slice_in_range = slice < max_slice;
@@ -522,7 +519,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
sum1 += filter_value * tile_ptr[0];
sum2 += filter_value * tile_ptr[tile_offset];
++shared_offset;
- filter_ptr += block_slices;
+ filter_ptr += kBlockSlices;
}
shared_offset += in_increment;
}
@@ -539,42 +536,43 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
const int block_rows = (args.in_rows + 1) / 2;
- const int block_slices = 8;
const int tile_cols = args.in_cols + args.filter_cols - 1;
const int tile_rows = block_rows * 2 + args.filter_rows - 1;
const int tile_pixels = tile_rows * tile_cols;
const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
- block_slices * (tile_pixels + filter_pixels) * sizeof(T);
+ kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T);
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
if (data_format == FORMAT_NHWC) {
- dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
+ dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_outputs, d,
DepthwiseConv2dGPUKernelNHWCSmall<T, kKnownFilterWidth,
- kKnownFilterHeight, kKnownEvenRows>,
+ kKnownFilterHeight, kBlockSlices,
+ kKnownEvenRows>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dGPUKernelNHWCSmall<T, kKnownFilterWidth, kKnownFilterHeight,
- kKnownEvenRows>
+ kBlockSlices, kKnownEvenRows>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, input, filter, output);
} else if (data_format == FORMAT_NCHW) {
- dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+ dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_outputs, d,
DepthwiseConv2dGPUKernelNCHWSmall<T, kKnownFilterWidth,
- kKnownFilterHeight, kKnownEvenRows>,
+ kKnownFilterHeight, kBlockSlices,
+ kKnownEvenRows>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dGPUKernelNCHWSmall<T, kKnownFilterWidth, kKnownFilterHeight,
- kKnownEvenRows>
+ kBlockSlices, kKnownEvenRows>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, input, filter, output);
} else {
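For a sense of scale, the shared-memory request above, worked for a 32x32 float input with a 3x3 filter and kBlockSlices = 2 (arithmetic only, following the definitions in this hunk):

    const int block_rows = (32 + 1) / 2;            // 16
    const int tile_cols = 32 + 3 - 1;               // 34
    const int tile_rows = block_rows * 2 + 3 - 1;   // 34
    const int tile_pixels = tile_rows * tile_cols;  // 1156
    const int filter_pixels = 3 * 3;                // 9
    const int shared_memory_size =
        2 * (tile_pixels + filter_pixels) * sizeof(float);  // 9320 bytes

The NHWC block is dim3(2, 32, 16), i.e. exactly 1024 threads, matching the __launch_bounds__(1024, 2) on the kernels.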
@@ -582,17 +580,37 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args,
}
}
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ int kBlockSlices>
void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
if (args.in_rows & 1) {
LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight,
- /*kKnownEvenRows=*/false>(
- d, args, input, filter, output, data_format);
+ kBlockSlices, false>(d, args, input, filter,
+ output, data_format);
} else {
LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight,
- /*kKnownEvenRows=*/true>(
+ kBlockSlices, true>(d, args, input, filter,
+ output, data_format);
+ }
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format) {
+ // Maximize (power of two) kBlockSlices while keeping a block within 1024
+ // threads (2 pixels per thread).
+ const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols;
+ if (block_pixels > 256) {
+ LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight, 2>(
+ d, args, input, filter, output, data_format);
+ } else if (block_pixels > 128) {
+ LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight, 4>(
+ d, args, input, filter, output, data_format);
+ } else {
+ LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight, 8>(
d, args, input, filter, output, data_format);
}
}
@@ -602,11 +620,6 @@ template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
- if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
- LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight>(
- d, args, input, filter, output, data_format);
- return;
- }
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
// The compile-time constant version runs faster with a single block.
@@ -641,18 +654,36 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
}
}
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format) {
+ if (args.depth_multiplier == 1) {
+ if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
+ LaunchDepthwiseConv2dGPUSmall<T, kKnownFilterWidth, kKnownFilterHeight>(
+ d, args, input, filter, output, data_format);
+ return;
+ }
+
+ LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, 1>(
+ d, args, input, filter, output, data_format);
+ } else {
+ LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, -1>(
+ d, args, input, filter, output, data_format);
+ }
+}
+
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
const T* filter, T* output, TensorFormat data_format) {
- if (args.filter_rows == 3 && args.filter_cols == 3 &&
- args.depth_multiplier == 1) {
- LaunchDepthwiseConv2dGPU<T, 3, 3, 1>(d, args, input, filter, output,
- data_format);
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
+ data_format);
} else {
- LaunchDepthwiseConv2dGPU<T, -1, -1, -1>(d, args, input, filter, output,
- data_format);
+ LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
+ data_format);
}
}
};
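With the depth-multiplier split moved into the overload above, a hypothetical call site only chooses the filter specialization (device and buffer setup elided; FORMAT_NHWC is one of the two supported formats):

    // 'd' is a GpuDevice; input/filter/output are device buffers.
    DepthwiseConv2dGPULaunch<float>::Run(d, args, input, filter, output,
                                         FORMAT_NHWC);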
@@ -726,7 +757,7 @@ __global__ void __launch_bounds__(640, 2)
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. input in
-// NHWC format, tailored for small images up to 16x16. Stride and depth
+// NHWC format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME', which allows reusing the index
// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args)
// returns true.
@@ -736,7 +767,7 @@ __global__ void __launch_bounds__(640, 2)
// performing the convolution. Each thread handles two elements per iteration,
// one each in the lower and upper half of a tile.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
@@ -757,24 +788,23 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
const int pad_cols = args.pad_cols;
// Fixed blockDim.x, corresponding to Pascal's global load granularity of 32B.
- const int block_slices = 8;
const int block_rows = blockDim.z;
// These values are the same for all threads and could
// be precomputed on the CPU.
- const int block_size = block_rows * in_cols * block_slices;
+ const int block_size = block_rows * in_cols * kBlockSlices;
const int in_row_size = in_cols * in_depth;
const int in_size = in_rows * in_row_size;
- const int in_increment = (in_cols - 1) * block_slices;
+ const int in_increment = (in_cols - 1) * kBlockSlices;
const int filter_pixels = filter_rows * filter_cols;
const int tile_cols = in_cols + filter_cols - 1;
const int even_rows = kKnownEvenRows || (1 & ~in_rows);
const int tile_rows = in_rows + filter_rows - even_rows;
- const int tile_row_size = tile_cols * block_slices;
+ const int tile_row_size = tile_cols * kBlockSlices;
const int tile_size = tile_rows * tile_row_size;
const int tile_offset = block_rows * tile_row_size;
const int pad_offset = pad_rows * tile_cols + pad_cols;
- const int batch_blocks = (in_depth + block_slices - 1) / block_slices;
+ const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices;
const int in_blocks = batch_blocks * batches;
const int tensor_offset =
kKnownEvenRows ? in_size / 2 : block_rows * in_row_size;
@@ -785,7 +815,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
// Position in block.
const int thread_pix = thread_row * in_cols + thread_col;
- const int thread_idx = thread_pix * block_slices + thread_depth;
+ const int thread_idx = thread_pix * kBlockSlices + thread_depth;
// Initialize tile, in particular the padding.
for (int i = thread_idx; i < tile_size; i += block_size) {
@@ -798,17 +828,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
// Position in (padded) shared memory.
const int data_pix = thread_row * tile_cols + thread_col;
- const int data_idx = data_pix * block_slices + thread_depth;
+ const int data_idx = data_pix * kBlockSlices + thread_depth;
// Position in shared memory, offset by pad_rows / pad_cols.
const int tile_pix = data_pix + pad_offset;
- const int tile_idx = tile_pix * block_slices + thread_depth;
+ const int tile_idx = tile_pix * kBlockSlices + thread_depth;
const int max_depth = in_depth - thread_depth;
const int filter_write_offset =
thread_pix < filter_pixels ? tile_size + thread_idx : 0;
const int filter_read_offset =
- tile_size + filter_pixels * block_slices + thread_depth;
+ tile_size + filter_pixels * kBlockSlices + thread_depth;
const bool skip_second =
!kKnownEvenRows && thread_row + (in_rows & 1) == block_rows;
@@ -816,7 +846,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
const int batch = b / batch_blocks;
const int stack = b - batch * batch_blocks;
- const int start_depth = stack * block_slices;
+ const int start_depth = stack * kBlockSlices;
const int filter_offset = tensor_idx + start_depth;
const int inout_offset = batch * in_size + filter_offset;
const bool depth_in_range = start_depth < max_depth;
@@ -844,12 +874,12 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall(
const T* filter_ptr = filter_read_offset + shared_data;
UNROLL for (int r = 0; r < filter_rows; ++r) {
UNROLL for (int c = 0; c < filter_cols; ++c) {
- filter_ptr -= block_slices;
+ filter_ptr -= kBlockSlices;
const T filter_value = *filter_ptr;
const T* const tile_ptr = shared_offset + shared_data;
sum1 += filter_value * tile_ptr[0];
sum2 += filter_value * tile_ptr[tile_offset];
- shared_offset += block_slices;
+ shared_offset += kBlockSlices;
}
shared_offset += in_increment;
}
@@ -937,7 +967,7 @@ __global__ void __launch_bounds__(640, 2)
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. input in
-// NHWC format, tailored for small images up to 16x16. Stride and depth
+// NHWC format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME', which allows reusing the index
// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args)
// returns true.
@@ -947,7 +977,7 @@ __global__ void __launch_bounds__(640, 2)
// performing the convolution. Each thread handles two elements per iteration,
// one each in the lower and upper half of a tile.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
@@ -969,12 +999,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
// Fixed blockDim.z, tailored for maximum grid size for images of size 16x16.
const int block_rows = blockDim.y;
- const int block_slices = 8;
// These values are the same for all threads and could
// be precomputed on the CPU.
const int block_pixels = in_cols * block_rows;
- const int block_size = block_pixels * block_slices;
+ const int block_size = block_pixels * kBlockSlices;
const int in_pixels = in_cols * in_rows;
const int in_increment = in_cols - 1;
const int filter_pixels = filter_rows * filter_cols;
@@ -982,11 +1011,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
const int even_rows = kKnownEvenRows || (1 & ~in_rows);
const int tile_rows = in_rows + filter_rows - even_rows;
const int tile_pixels = tile_cols * tile_rows;
- const int tile_size = tile_pixels * block_slices;
+ const int tile_size = tile_pixels * kBlockSlices;
const int tile_offset = block_rows * tile_cols;
const int pad_offset = pad_rows * tile_cols + pad_cols;
const int in_slices = in_depth * batches;
- const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+ const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices;
const int thread_col = threadIdx.x;
const int thread_row = threadIdx.y;
@@ -1013,20 +1042,20 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
const int tile_idx = data_idx + pad_offset;
// Filter is always in HWCK format, irrespective of the input/output format.
- const int filter_pix = thread_idx / block_slices;
- const int filter_depth = thread_idx % block_slices;
+ const int filter_pix = thread_idx / kBlockSlices;
+ const int filter_depth = thread_idx % kBlockSlices;
const int filter_idx = filter_pix * in_depth;
const int max_slice = in_slices - thread_depth;
const int filter_write_offset =
filter_pix < filter_pixels ? tile_size + thread_idx : 0;
const int filter_read_offset =
- tile_size + filter_pixels * block_slices + thread_depth;
+ tile_size + filter_pixels * kBlockSlices + thread_depth;
const bool skip_second =
!kKnownEvenRows && thread_row + (in_rows & 1) == block_rows;
for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
- const int slice = b * block_slices;
+ const int slice = b * kBlockSlices;
const int inout_offset = slice * in_pixels + tensor_idx;
const bool slice_in_range = slice < max_slice;
@@ -1055,7 +1084,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
const T* filter_ptr = filter_read_offset + shared_data;
UNROLL for (int r = 0; r < filter_rows; ++r) {
UNROLL for (int c = 0; c < filter_cols; ++c) {
- filter_ptr -= block_slices;
+ filter_ptr -= kBlockSlices;
const T filter_value = *filter_ptr;
const T* const tile_ptr = shared_offset + shared_data;
sum1 += filter_value * tile_ptr[0];
@@ -1077,44 +1106,45 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall(
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- bool kKnownEvenRows>
+ int kBlockSlices, bool kKnownEvenRows>
void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
TensorFormat data_format) {
const int block_rows = (args.in_rows + 1) / 2;
- const int block_slices = 8;
const int tile_cols = args.in_cols + args.filter_cols - 1;
const int tile_rows = block_rows * 2 + args.filter_rows - 1;
const int tile_pixels = tile_rows * tile_cols;
const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
- block_slices * (tile_pixels + filter_pixels) * sizeof(T);
+ kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T);
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
if (data_format == FORMAT_NHWC) {
- dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
+ dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_outputs, d,
DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>,
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices,
+ kKnownEvenRows>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dBackpropInputGPUKernelNHWCSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kKnownEvenRows>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, filter, in_backprop);
} else if (data_format == FORMAT_NCHW) {
- dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+ dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_outputs, d,
DepthwiseConv2dBackpropInputGPUKernelNCHWSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>,
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices,
+ kKnownEvenRows>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dBackpropInputGPUKernelNCHWSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kKnownEvenRows>
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kKnownEvenRows>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, filter, in_backprop);
} else {
@@ -1122,25 +1152,48 @@ void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
}
}
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ int kBlockSlices>
void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
TensorFormat data_format) {
if (args.in_rows & 1) {
- LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight,
- /*kKnownEvenRows=*/false>(
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, false>(
d, args, out_backprop, filter, in_backprop, data_format);
} else {
- LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight,
- /*kKnownEvenRows=*/true>(
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, true>(
d, args, out_backprop, filter, in_backprop, data_format);
}
}
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d,
+ const DepthwiseArgs args,
+ const T* input, const T* filter,
+ T* output,
+ TensorFormat data_format) {
+ // Maximize (power of two) kBlockSlices while keeping a block within 1024
+ // threads (2 pixels per thread).
+ const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols;
+ if (block_pixels > 256) {
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 2>(
+ d, args, input, filter, output, data_format);
+ } else if (block_pixels > 128) {
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 4>(
+ d, args, input, filter, output, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight, 8>(
+ d, args, input, filter, output, data_format);
+ }
+}
+
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
@@ -1148,12 +1201,6 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
const T* out_backprop,
const T* filter, T* in_backprop,
TensorFormat data_format) {
- if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
- LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>(
- d, args, out_backprop, filter, in_backprop, data_format);
- return;
- }
const int num_in_backprop =
args.batch * args.in_rows * args.in_cols * args.in_depth;
if (data_format == FORMAT_NHWC) {
@@ -1181,22 +1228,41 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
}
}
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
+ const DepthwiseArgs args,
+ const T* out_backprop,
+ const T* filter, T* in_backprop,
+ TensorFormat data_format) {
+ if (args.depth_multiplier == 1) {
+ if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
+ LaunchDepthwiseConv2dBackpropInputGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>(
+ d, args, out_backprop, filter, in_backprop, data_format);
+ return;
+ }
+
+ LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, 1>(
+ d, args, out_backprop, filter, in_backprop, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, -1>(
+ d, args, out_backprop, filter, in_backprop, data_format);
+ }
+}
+
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format) {
- if (args.depth_multiplier == 1) {
- if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3, 1>(
- d, args, out_backprop, filter, in_backprop, data_format);
- } else {
- LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, 1>(
- d, args, out_backprop, filter, in_backprop, data_format);
- }
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
+ d, args, out_backprop, filter, in_backprop, data_format);
} else {
- LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, -1>(
+ LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
d, args, out_backprop, filter, in_backprop, data_format);
}
}
@@ -1300,19 +1366,20 @@ __global__ void __launch_bounds__(640, 2)
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
-// NHWC format, tailored for small images up to 16x16. Stride and depth
+// NHWC format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
// CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
// Tiles of the input tensor are loaded into shared memory before performing the
// convolution. Per iteration and filter element, each thread first performs
// a partial convolution for two elements, one each in the lower and upper half
-// of a tile. The intermediate result of 4 consecutive columns are then
+// of a tile. The intermediate results of all pixels of a warp are then
// accumulated and written to shared memory. Finally, the values in shared
// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
// up in global memory using atomics.
+// Requirements: threads per block must be a multiple of 32 and <= launch_bounds,
+// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- // Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
- int kAccumPixels>
+ int kBlockSlices, int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1332,29 +1399,29 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
- // Fixed blockDim.x, corresponding to Pascal's global load granularity of 32B.
- const int block_slices = 8;
const int block_rows = blockDim.z;
// These values are the same for all threads and could
// be precomputed on the CPU.
- const int block_size = block_rows * in_cols * block_slices;
+ const int block_size = block_rows * in_cols * kBlockSlices;
+ assert((block_size & 31) == 0);
const int in_row_size = in_cols * in_depth;
const int in_size = in_rows * in_row_size;
- const int in_increment = (in_cols - 1) * block_slices;
+ const int in_increment = (in_cols - 1) * kBlockSlices;
const int filter_pixels = filter_rows * filter_cols;
const int tile_cols = in_cols + filter_cols - 1;
const int tile_rows = 2 * block_rows + filter_rows - 1;
- const int tile_row_size = tile_cols * block_slices;
+ const int tile_row_size = tile_cols * kBlockSlices;
const int tile_size = tile_rows * tile_row_size;
const int tile_offset = block_rows * tile_row_size;
const int pad_offset = pad_rows * tile_cols + pad_cols;
- const int batch_blocks = (in_depth + block_slices - 1) / block_slices;
+ const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices;
const int in_blocks = batch_blocks * batches;
const int tensor_offset = block_rows * in_row_size;
// The accumulator has a fixed number of pixels that can be reduced by one
- // warp. Pixels beyond block_pixels/4 are never written.
- const int accum_increment = kAccumPixels * block_slices;
+ // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written.
+ assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices);
+ const int accum_increment = kAccumPixels * kBlockSlices;
const int accum_size = filter_pixels * accum_increment;
const int thread_depth = threadIdx.x;
@@ -1363,7 +1430,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Position in block.
const int thread_pix = thread_row * in_cols + thread_col;
- const int thread_idx = thread_pix * block_slices + thread_depth;
+ const int thread_idx = thread_pix * kBlockSlices + thread_depth;
// Initialize tile, in particular the padding and accumulator.
for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
@@ -1376,14 +1443,14 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Position in (padded) shared memory.
const int data_pix = thread_row * tile_cols + thread_col;
- const int data_idx = data_pix * block_slices + thread_depth;
+ const int data_idx = data_pix * kBlockSlices + thread_depth;
// Position in shared memory, offset by pad_rows / pad_cols.
const int tile_pix = data_pix + pad_offset;
- const int tile_idx = tile_pix * block_slices + thread_depth;
+ const int tile_idx = tile_pix * kBlockSlices + thread_depth;
- // Position in accumulator (1 per 4 threads, depth major).
- const int accum_pix = thread_pix / 4;
+ // Position in accumulator (kBlockSlices per warp, depth major).
+ const int accum_pix = thread_pix / (32 / kBlockSlices);
const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_depth = in_depth - thread_depth;
@@ -1394,7 +1461,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int batch = b / batch_blocks;
const int stack = b - batch * batch_blocks;
- const int start_depth = stack * block_slices;
+ const int start_depth = stack * kBlockSlices;
const int filter_offset = tensor_idx + start_depth;
const int inout_offset = batch * in_size + filter_offset;
const bool depth_in_range = start_depth < max_depth;
@@ -1421,13 +1488,14 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
UNROLL for (int c = 0; c < filter_cols; ++c) {
const T* const tile_ptr = shared_offset + shared_data;
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
- // Sum up 4 block_pixels of the same depth and write to accumulator.
- val += CudaShuffleDown(val, 16);
- val += CudaShuffleDown(val, 8);
- if (!(thread_idx & 24) /* i.e. 'lane_idx < 8' */) {
+ // Warp-accumulate pixels of the same depth and write to accumulator.
+ for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
+ val += CudaShuffleDown(val, delta);
+ }
+ if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
}
- shared_offset += block_slices;
+ shared_offset += kBlockSlices;
accum_ptr += accum_increment;
}
shared_offset += in_increment;
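The generalized loop sums, per depth, the 32 / kBlockSlices lanes of a warp that share that depth; in NHWC, depth is the fastest-varying thread index, so same-depth lanes sit kBlockSlices apart. A self-contained sketch of the pattern, assuming warpSize == 32 and written against __shfl_down_sync (the patch's CudaShuffleDown is TensorFlow's wrapper over the pre-Volta __shfl_down):

    template <int kBlockSlices>
    __device__ float WarpAccumulateSameDepth(float val) {
      // Afterwards, lane L (L < kBlockSlices) holds the sum over lanes
      // L, L + kBlockSlices, L + 2 * kBlockSlices, ...
      for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
        val += __shfl_down_sync(0xffffffffu, val, delta);
      }
      return val;
    }

With kBlockSlices = 8 this reduces to the two shuffles (16, 8) of the old code, and the guard !(thread_idx & 32 - kBlockSlices) writes from exactly lanes 0 .. kBlockSlices - 1.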
@@ -1440,8 +1508,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
const int filter_idx = i / kAccumPixels;
- const int filter_pix = filter_idx / block_slices;
- const int filter_depth = filter_idx % block_slices + start_depth;
+ const int filter_pix = filter_idx / kBlockSlices;
+ const int filter_depth = filter_idx % kBlockSlices + start_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
@@ -1558,19 +1626,20 @@ __global__ void __launch_bounds__(640, 2)
}
// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
-// NCHW format, tailored for small images up to 16x16. Stride and depth
+// NCHW format, tailored for small images up to 32x32. Stride and depth
// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
// CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
// Tiles of the input tensor are loaded into shared memory before performing the
// convolution. Per iteration and filter element, each thread first performs
// a partial convolution for two elements, one each in the lower and upper half
-// of a tile. The intermediate result of 4 consecutive columns are then
+// of a tile. The intermediate results of all pixels of a warp are then
// accumulated and written to shared memory. Finally, the values in shared
// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
// up in global memory using atomics.
+// Requirements: threads per block must be a multiple of 32 and <= launch_bounds,
+// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- // Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
- int kAccumPixels>
+ int kBlockSlices, int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1590,28 +1659,28 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
- // Fixed blockDim.z, corresponding to Pascal's global load granularity of 32B.
const int block_rows = blockDim.y;
- const int block_slices = 8;
// These values are the same for all threads and could
// be precomputed on the CPU.
const int block_pixels = in_cols * block_rows;
- const int block_size = block_pixels * block_slices;
+ const int block_size = block_pixels * kBlockSlices;
+ assert((block_size & 31) == 0);
const int in_pixels = in_cols * in_rows;
const int in_increment = in_cols - 1;
const int filter_pixels = filter_rows * filter_cols;
const int tile_cols = in_cols + filter_cols - 1;
const int tile_rows = 2 * block_rows + filter_rows - 1;
const int tile_pixels = tile_cols * tile_rows;
- const int tile_size = tile_pixels * block_slices;
+ const int tile_size = tile_pixels * kBlockSlices;
const int tile_offset = block_rows * tile_cols;
const int pad_offset = pad_rows * tile_cols + pad_cols;
const int in_slices = in_depth * batches;
- const int in_blocks = (in_slices + block_slices - 1) / block_slices;
+ const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices;
// The accumulator has a fixed number of pixels that can be reduced by one
- // warp. Pixels beyond block_pixels/4 are never written.
- const int accum_increment = kAccumPixels * block_slices;
+ // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written.
+ assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices);
+ const int accum_increment = kAccumPixels * kBlockSlices;
const int accum_size = filter_pixels * accum_increment;
const int thread_col = threadIdx.x;
@@ -1638,8 +1707,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Position in shared memory, offset by pad_rows / pad_cols.
const int tile_idx = data_idx + pad_offset;
- // Position in accumulator (1 per 4 threads, depth major).
- const int accum_pix = thread_pix / 4;
+ // Position in accumulator (kBlockSlices per warp, depth major).
+ const int accum_pix = thread_pix / (32 / kBlockSlices);
const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_slice = in_slices - thread_depth;
@@ -1647,7 +1716,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const bool skip_second = block_rows + thread_row >= in_rows;
for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
- const int slice = b * block_slices;
+ const int slice = b * kBlockSlices;
const int inout_offset = slice * in_pixels + tensor_idx;
const bool slice_in_range = slice < max_slice;
@@ -1674,10 +1743,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
UNROLL for (int c = 0; c < filter_cols; ++c) {
const T* const tile_ptr = shared_offset + shared_data;
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
- // Sum up 4 block_pixels of the same depth and write to accumulator.
- val += CudaShuffleDown(val, 2);
- val += CudaShuffleDown(val, 1);
- if (!(thread_idx & 3)) {
+ // Warp-accumulate pixels of the same depth and write to accumulator.
+ for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
+ val += CudaShuffleDown(val, delta);
+ }
+ if (!(thread_idx & 32 / kBlockSlices - 1)) {
*accum_ptr = val;
}
++shared_offset;
@@ -1693,8 +1763,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
const int filter_idx = i / kAccumPixels;
- const int filter_pix = filter_idx / block_slices;
- const int filter_depth = (slice + filter_idx % block_slices) % in_depth;
+ const int filter_pix = filter_idx / kBlockSlices;
+ const int filter_depth = (slice + filter_idx % kBlockSlices) % in_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
@@ -1711,87 +1781,121 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- int kAccumPixels>
-void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
- const GpuDevice& d, const DepthwiseArgs args, int block_rows,
- int shared_memory_size, const T* out_backprop, const T* input,
- T* filter_backprop, TensorFormat data_format) {
- const int block_slices = 8;
+ int kBlockSlices, int kAccumPixels>
+bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ const GpuDevice& d, const DepthwiseArgs args, const int block_rows,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format) {
+ const int tile_cols = args.in_cols + args.filter_cols - 1;
+ const int tile_rows = block_rows * 2 + args.filter_rows - 1;
+ const int tile_pixels = tile_rows * tile_cols;
+ const int filter_pixels = args.filter_rows * args.filter_cols;
+ const int shared_memory_size =
+ kBlockSlices * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T);
+ if (shared_memory_size > d.sharedMemPerBlock()) {
+ return false;
+ }
+
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
if (data_format == FORMAT_NHWC) {
- dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
+ dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices,
+ kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else if (data_format == FORMAT_NCHW) {
- dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
+ dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices,
+ kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else {
assert(false && "Incorrect data format");
}
+ return true;
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ int kBlockSlices>
+bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ const GpuDevice& d, const DepthwiseArgs args, const int block_rows,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format) {
+ // Minimize (power of two) kAccumPixels, while satisfying
+ // kAccumPixels * 32 >= block_rows * in_cols * kBlockSlices.
+ const int block_pixels = block_rows * args.in_cols * kBlockSlices;
+ if (block_pixels > 512) {
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 32>(
+ d, args, block_rows, out_backprop, input, filter_backprop, data_format);
+ } else if (block_pixels > 256) {
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 16>(
+ d, args, block_rows, out_backprop, input, filter_backprop, data_format);
+ } else {
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 8>(
+ d, args, block_rows, out_backprop, input, filter_backprop, data_format);
+ }
}
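This chooses the smallest power-of-two kAccumPixels with kAccumPixels * 32 >= block_rows * in_cols * kBlockSlices, i.e. one accumulator pixel per warp of block threads; since block_rows >= in_rows / 2, that also implies the kernels' assert kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices. Worked for the largest admissible case (arithmetic only):

    // 32x32 image, kBlockSlices = 2, block_rows = 16:
    //   block threads   = 16 * 32 * 2 = 1024  (> 512 -> kAccumPixels = 32)
    //   selection bound: 32 * 32 = 1024 >= 1024
    //   kernel assert:   32 * 64 = 2048 >= 32 * 32 * 2 = 2048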
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop,
const T* input, T* filter_backprop, TensorFormat data_format) {
- // args.in_cols * blocks_rows (block_pixels) must be multiple of 4.
- const int lookup_table[] = {0, 3, 1, 3};
- const int rows_mask = lookup_table[args.in_cols & 3];
- const int block_rows = (args.in_rows + 1) / 2 + rows_mask & ~rows_mask;
- if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) {
- return false;
- }
-
- const int in_pixels = args.in_rows * args.in_cols;
- int accum_pixels = 8;
- while (accum_pixels * 8 < in_pixels) {
- accum_pixels *= 2;
+ // Maximize (power of two) kBlockSlices while keeping a block within 1024
+ // threads (2 pixels per thread).
+ int block_slices = 8;
+ int block_rows = (args.in_rows + 1) / 2;
+ int round_mask = 1;
+ for (; block_slices > 1; block_slices /= 2) {
+ // args.in_cols * block_rows * kBlockSlices must be a multiple of 32.
+ for (; block_rows * args.in_cols * block_slices & 31;
+ round_mask = round_mask * 2 + 1) {
+ block_rows = block_rows + round_mask & ~round_mask;
+ }
+ int block_size = block_rows * args.in_cols * block_slices;
+ if (block_size <= 1024) {
+ break;
+ }
}
- const int block_slices = 8;
- const int tile_cols = args.in_cols + args.filter_cols - 1;
- const int tile_rows = block_rows * 2 + args.filter_rows - 1;
- const int tile_pixels = tile_rows * tile_cols;
- const int filter_pixels = args.filter_rows * args.filter_cols;
- const int shared_memory_size =
- block_slices * (tile_pixels + filter_pixels * accum_pixels) * sizeof(T);
- if (shared_memory_size > d.sharedMemPerBlock()) {
+ if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) {
return false;
}
- if (accum_pixels == 8) {
- LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight, 8>(
- d, args, block_rows, shared_memory_size, out_backprop, input,
- filter_backprop, data_format);
- } else if (accum_pixels == 16) {
- LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight, 16>(
- d, args, block_rows, shared_memory_size, out_backprop, input,
- filter_backprop, data_format);
- } else {
- LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight, 32>(
- d, args, block_rows, shared_memory_size, out_backprop, input,
- filter_backprop, data_format);
+ switch (block_slices) {
+ case 8:
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, 8>(
+ d, args, block_rows, out_backprop, input, filter_backprop,
+ data_format);
+ case 4:
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, 4>(
+ d, args, block_rows, out_backprop, input, filter_backprop,
+ data_format);
+ case 2:
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, 2>(
+ d, args, block_rows, out_backprop, input, filter_backprop,
+ data_format);
+ default:
+ return false;
}
- return true;
}
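A trace of the block-shape search for a 31x31 image (values only, following the loop above; 16 * 31 * kBlockSlices is already a multiple of 32 at every step, so the round_mask path never fires here):

    // in_rows = in_cols = 31, block_rows = (31 + 1) / 2 = 16
    // kBlockSlices = 8: block_size = 16 * 31 * 8 = 3968   > 1024, halve
    // kBlockSlices = 4: block_size = 16 * 31 * 4 = 1984   > 1024, halve
    // kBlockSlices = 2: block_size = 16 * 31 * 2 =  992  <= 1024, stop
    // -> dispatch to the kBlockSlices = 2 instantiation with block_rows = 16

For a width like in_cols = 27, the inner loop would first round block_rows up (from 14 to 16 via round_mask = 3) to make the block size a multiple of 32.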
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -1801,11 +1905,6 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
const T* out_backprop,
const T* input, T* filter_backprop,
TensorFormat data_format) {
- if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
- kKnownFilterHeight>(
- d, args, out_backprop, input, filter_backprop, data_format)) {
- return;
- }
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
if (data_format == FORMAT_NHWC) {
@@ -1833,18 +1932,40 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
}
}
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
+ const DepthwiseArgs args,
+ const T* out_backprop,
+ const T* input, T* filter_backprop,
+ TensorFormat data_format) {
+ if (args.depth_multiplier == 1) {
+ if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+ kKnownFilterHeight>(
+ d, args, out_backprop, input, filter_backprop, data_format)) {
+ return;
+ }
+
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, 1>(
+ d, args, out_backprop, input, filter_backprop, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, -1>(
+ d, args, out_backprop, input, filter_backprop, data_format);
+ }
+}
+
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
- if (args.filter_rows == 3 && args.filter_cols == 3 &&
- args.depth_multiplier == 1) {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3, 1>(
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
d, args, out_backprop, input, filter_backprop, data_format);
} else {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1, -1>(
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
d, args, out_backprop, input, filter_backprop, data_format);
}
}
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 8ba9d0efff..3298092fbe 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -37,18 +37,21 @@ def ConfigsToTest():
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
convolution parameters.
"""
- input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2],
- [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
- filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1],
- [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
- out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2],
- [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]]
- strides = [1, 1, 1, 1, 3, 2, 2]
+ input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
+ [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
+ [3, 299, 299, 3], [5, 183, 183, 1]]
+ filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
+ [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
+ 8], [5, 5, 1, 2]]
+ out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
+ [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
+ [3, 150, 150, 24], [5, 92, 92, 2]]
+ strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
- paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
+ paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
paddings):
yield i, f, o, s, p