Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 264
1 file changed, 192 insertions, 72 deletions
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 91b50b1e11..0c638ca233 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -199,32 +199,145 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
   }
 }
 
-#undef CUDA_1D_KERNEL_LOOP
-}  // namespace
+// The parameters to the kernels in the gradient gradient function are as
+// follows:
+//     nthreads: the number of threads, which is equal to the output size. The
+//         gradient of the MaxPooling gradient w.r.t. the output data has
+//         dimensions of N*C*Hout*Wout.
+//     bottom_data: the bottom data of N*H*W*C (or N*C*H*W) items.
+//     output_data: the output data of N*Hout*Wout*C (or N*C*Hout*Wout) items.
+//     height, width, pooled_height, pooled_width: the input and output sizes.
+//     kernel_h, kernel_w: the kernel sizes.
+//     stride_h, stride_w: the strides.
+//     pad_t, pad_l: the padding values on the top and left side.
+//     top_diff: the gradient of the gradient of the output data w.r.t. the
+//         input data, of size N*H*W*C (or N*C*H*W).
+//     bottom_diff: the gradient of the gradient w.r.t. output.
+template <typename dtype>
+__global__ void MaxPoolGradBackwardNoMaskNCHW(
+    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int pooled_height, const int pooled_width, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
+    const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+    const dtype* top_diff, dtype* bottom_diff) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // First find the index of the maximum, since we have no mask.
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    int hstart = ph * stride_h - pad_t;
+    int wstart = pw * stride_w - pad_l;
+    const int hend = min(hstart + kernel_h, height);
+    const int wend = min(wstart + kernel_w, width);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    bool should_stop = false;
+    int maxidx = -1;
+    const dtype* bottom_data_n = bottom_data + n * channels * height * width;
+    // Propagate only the first value from top_diff that corresponds to the
+    // maximum.
+    for (int h = hstart; h < hend && !should_stop; ++h) {
+      for (int w = wstart; w < wend && !should_stop; ++w) {
+        int idx = c * height * width + h * width + w;
+        if (output_data[index] == bottom_data_n[idx]) {
+          maxidx = idx;
+          should_stop = true;
+        }
+      }
+    }
+    // Set the bottom diff (an atomic is not necessary). The index could still
+    // be uninitialized if all the bottom_data are NaN.
+    if (maxidx != -1) {
+      bottom_diff[index] = top_diff[n * channels * height * width + maxidx];
+    }
+  }
+}
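For illustration, and not part of the patch: the NCHW kernel above recovers (n, c, ph, pw) from a flat output index by repeated division and modulo. A minimal host-side sketch (all sizes hypothetical) checks that this decomposition inverts the usual row-major flattening:

#include <cassert>

int main() {
  const int batch = 2, channels = 3, pooled_height = 4, pooled_width = 5;
  for (int index = 0;
       index < batch * channels * pooled_height * pooled_width; ++index) {
    // Same decomposition as MaxPoolGradBackwardNoMaskNCHW.
    const int pw = index % pooled_width;
    const int ph = (index / pooled_width) % pooled_height;
    const int c = (index / pooled_width / pooled_height) % channels;
    const int n = index / pooled_width / pooled_height / channels;
    // Re-flattening must give back the original index.
    assert(((n * channels + c) * pooled_height + ph) * pooled_width + pw ==
           index);
  }
}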
 
-bool MaxPoolForwardWithOptionalArgmax(
-    const float* bottom_data, const int batch, const int height,
-    const int width, const int channels, const int pooled_height,
-    const int pooled_width, const int kernel_h, const int kernel_w,
+template <typename dtype>
+__global__ void MaxPoolGradBackwardNoMaskNHWC(
+    const int nthreads, const dtype* bottom_data, const dtype* output_data,
+    const int pooled_height, const int pooled_width, const int channels,
+    const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    float* top_data, int64* mask, const Eigen::GpuDevice& d) {
-  const int kThreadsPerBlock = 1024;
-  const int output_size = batch * channels * pooled_height * pooled_width;
+    const dtype* top_diff, dtype* bottom_diff) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // First find the index of the maximum, since we have no mask.
+    int n = index;
+    int c = n % channels;
+    n /= channels;
+    int wstart = (n % pooled_width) * stride_w - pad_l;
+    n /= pooled_width;
+    int hstart = (n % pooled_height) * stride_h - pad_t;
+    n /= pooled_height;
+    int hend = min(hstart + kernel_h, height);
+    int wend = min(wstart + kernel_w, width);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    bool should_stop = false;
+    int maxidx = -1;
+    const dtype* bottom_data_n = bottom_data + n * height * width * channels;
+    // Propagate only the first value from top_diff that corresponds to the
+    // maximum.
+    for (int h = hstart; h < hend && !should_stop; ++h) {
+      for (int w = wstart; w < wend && !should_stop; ++w) {
+        int idx = (h * width + w) * channels + c;
+        if (output_data[index] == bottom_data_n[idx]) {
+          maxidx = idx;
+          should_stop = true;
+        }
+      }
+    }
+    // Set the bottom diff (an atomic is not necessary). The index could still
+    // be uninitialized if all the bottom_data are NaN.
+    if (maxidx != -1) {
+      bottom_diff[index] = top_diff[n * height * width * channels + maxidx];
+    }
+  }
+}
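For illustration, and not part of the patch: a scalar CPU analogue of the NHWC kernel body above, for a single output element. It mirrors the two properties the in-code comments call out: the first window element equal to the pooled output wins, and when nothing matches (e.g. an all-NaN window, since NaN != NaN) the output element is left untouched. All names and window bounds are hypothetical:

// bottom_data_n/top_diff_n point at the start of one image;
// [hstart,hend) x [wstart,wend) is one pooling window in channel c.
void GradGradOneOutputNHWC(const float* bottom_data_n, const float* top_diff_n,
                           float output_val, int width, int channels, int c,
                           int hstart, int hend, int wstart, int wend,
                           float* bottom_diff_elem) {
  int maxidx = -1;
  for (int h = hstart; h < hend && maxidx == -1; ++h) {
    for (int w = wstart; w < wend && maxidx == -1; ++w) {
      const int idx = (h * width + w) * channels + c;
      if (output_val == bottom_data_n[idx]) maxidx = idx;  // first match wins
    }
  }
  if (maxidx != -1) *bottom_diff_elem = top_diff_n[maxidx];  // else untouched
}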
 
-  MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-                       kThreadsPerBlock, 0, d.stream()>>>(
-      output_size, bottom_data, height, width, channels, pooled_height,
-      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-      top_data, mask);
-  return d.ok();
+// The parameters to the kernels in the gradient gradient function are as
+// follows:
+//     nthreads: the number of threads, which is equal to the output size. The
+//         gradient of the MaxPooling gradient w.r.t. the output data has
+//         dimensions of N*C*Hout*Wout.
+//     top_diff: the gradient of the gradient of the output data w.r.t. the
+//         input data, of size N*H*W*C (or N*C*H*W). As we have stored the
+//         flattened index of the input entries, the backward function is
+//         agnostic of the input storage order.
+//     mask: the output mask of the same size as top_data. It is stored in
+//         int form, keeping track of the flattened index of the input item
+//         that produces the max output.
+//     top_offset: the pre-computed per-image offset of the maxpool input
+//         gradient. This is equal to H*W*C. We choose to pre-compute this so
+//         we do not need to compute it every time inside the kernel.
+//     bottom_offset: the pre-computed per-image offset of the maxpool output.
+//         This is equal to Hout*Wout*C.
+//     bottom_diff: the gradient of the gradient w.r.t. output.
+// Unlike MaxPoolBackward, each output element here is written exactly once,
+// so neither CudaAtomicAdd nor prior zero-filling of bottom_diff is needed.
+template <typename dtype>
+__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
+                                    const int64* mask, const int top_offset,
+                                    const int bottom_offset,
+                                    dtype* bottom_diff) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    int image_id = (index / bottom_offset);
+    bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
+  }
 }
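For illustration, and not part of the patch: the mask-based kernel above is a pure gather, so a CPU reference is a single loop. top_offset (H*W*C) converts the per-image mask entries into global offsets, and bottom_offset (Hout*Wout*C) recovers the image id; names and the long long stand-in for int64 are hypothetical:

void MaxPoolGradBackwardCPU(int nthreads, const float* top_diff,
                            const long long* mask, int top_offset,
                            int bottom_offset, float* bottom_diff) {
  for (int index = 0; index < nthreads; ++index) {
    const int image_id = index / bottom_offset;
    // mask[index] is the flattened in-image argmax saved by the forward pass,
    // so every output element is read from exactly one input location.
    bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
  }
}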
 
-bool MaxPoolForwardWithOptionalArgmax(
-    const Eigen::half* bottom_data, const int batch, const int height,
-    const int width, const int channels, const int pooled_height,
-    const int pooled_width, const int kernel_h, const int kernel_w,
-    const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) {
+#undef CUDA_1D_KERNEL_LOOP
+}  // namespace
+
+namespace functor {
+
+template <typename T>
+bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
+    const T* bottom_data, const int batch, const int height, const int width,
+    const int channels, const int pooled_height, const int pooled_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_t, const int pad_l, T* top_data,
+    int64* mask, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   const int output_size = batch * channels * pooled_height * pooled_width;
@@ -236,14 +349,13 @@ bool MaxPoolForwardWithOptionalArgmax(
   return d.ok();
 }
 
-bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
-                           const int height, const int width,
-                           const int channels, const int pooled_height,
-                           const int pooled_width, const int kernel_h,
-                           const int kernel_w, const int stride_h,
-                           const int stride_w, const int pad_t, const int pad_l,
-                           const float* top_diff, float* bottom_diff,
-                           const Eigen::GpuDevice& d) {
+template <typename T>
+bool MaxPoolBackwardNoMask<T>::operator()(
+    const T* bottom_data, const int batch, const int height, const int width,
+    const int channels, const int pooled_height, const int pooled_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
+    T* bottom_diff, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   const int bottom_size = batch * channels * height * width;
   const int top_size = batch * channels * pooled_height * pooled_width;
@@ -260,34 +372,11 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
   return d.ok();
 }
 
-bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
-                           const int height, const int width,
-                           const int channels, const int pooled_height,
-                           const int pooled_width, const int kernel_h,
-                           const int kernel_w, const int stride_h,
-                           const int stride_w, const int pad_t, const int pad_l,
-                           const Eigen::half* top_diff,
-                           Eigen::half* bottom_diff,
-                           const Eigen::GpuDevice& d) {
-  const int kThreadsPerBlock = 1024;
-  const int bottom_size = batch * channels * height * width;
-  const int top_size = batch * channels * pooled_height * pooled_width;
-
-  SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-            kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
-
-  MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
-                                  kThreadsPerBlock,
-                              kThreadsPerBlock, 0, d.stream()>>>(
-      top_size, bottom_data, height, width, channels, pooled_height,
-      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-      top_diff, bottom_diff);
-  return d.ok();
-}
-
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
-                               const float* top_diff, const int64* mask,
-                               const int top_offset, const int bottom_offset,
-                               float* bottom_diff, const Eigen::GpuDevice& d) {
+template <typename T>
+bool MaxPoolBackwardWithArgmax<T>::operator()(
+    const int output_size, const int input_size, const T* top_diff,
+    const int64* mask, const int top_offset, const int bottom_offset,
+    T* bottom_diff, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
             kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
@@ -297,30 +386,61 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
   return d.ok();
 }
 
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
-                               const Eigen::half* top_diff, const int64* mask,
-                               const int top_offset, const int bottom_offset,
-                               Eigen::half* bottom_diff,
-                               const Eigen::GpuDevice& d) {
-  const int kThreadsPerBlock = 1024;
-  SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-            kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
-  MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-                    kThreadsPerBlock, 0, d.stream()>>>(
-      output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
+template <typename T>
+bool MaxPoolGradBackwardNoMask<T>::operator()(
+    TensorFormat data_format, const T* bottom_data, const T* output_data,
+    const int batch, const int pooled_height, const int pooled_width,
+    const int channels, const int height, const int width, const int kernel_h,
+    const int kernel_w, const int stride_h, const int stride_w,
+    const int pad_t, const int pad_l, const T* top_diff, T* bottom_diff,
+    const Eigen::GpuDevice& d) {
+  const int num_kernels = batch * channels * pooled_height * pooled_width;
+  CudaLaunchConfig config = GetCudaLaunchConfig(num_kernels, d);
+
+  if (data_format == FORMAT_NHWC) {
+    MaxPoolGradBackwardNoMaskNHWC<<<config.block_count,
+                                    config.thread_per_block, 0, d.stream()>>>(
+        num_kernels, bottom_data, output_data, pooled_height, pooled_width,
+        channels, height, width, kernel_h, kernel_w, stride_h, stride_w,
+        pad_t, pad_l, top_diff, bottom_diff);
+  } else {
+    MaxPoolGradBackwardNoMaskNCHW<<<config.block_count,
+                                    config.thread_per_block, 0, d.stream()>>>(
+        num_kernels, bottom_data, output_data, pooled_height, pooled_width,
+        channels, height, width, kernel_h, kernel_w, stride_h, stride_w,
+        pad_t, pad_l, top_diff, bottom_diff);
+  }
+  return d.ok();
+}
+
+template <typename T>
+bool MaxPoolGradBackwardWithArgmax<T>::operator()(
+    const int output_size, const int input_size, const T* top_diff,
+    const int64* mask, const int top_offset, const int bottom_offset,
+    T* bottom_diff, const Eigen::GpuDevice& d) {
+  CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d);
+  MaxPoolGradBackward<<<config.block_count, config.thread_per_block, 0,
+                        d.stream()>>>(output_size, top_diff, mask, top_offset,
                                       bottom_offset, bottom_diff);
   return d.ok();
 }
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#define DEFINE_GPU_KERNELS(T) \
-  template struct functor::SpatialMaxPooling<GPUDevice, T>;
+#define DEFINE_GPU_KERNELS(T)                          \
+  template struct SpatialMaxPooling<GPUDevice, T>;     \
+  template struct MaxPoolForwardWithOptionalArgmax<T>; \
+  template struct MaxPoolBackwardWithArgmax<T>;        \
+  template struct MaxPoolBackwardNoMask<T>;            \
+  template struct MaxPoolGradBackwardWithArgmax<T>;    \
+  template struct MaxPoolGradBackwardNoMask<T>;
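For illustration, and not part of the patch: moving these functions into functor structs with a templated operator() lets callers launch them generically over T. A hypothetical call site (all variable names assumed) would look like:

// Inside some Op's Compute method, with pointers already on the GPU:
bool ok = functor::MaxPoolGradBackwardWithArgmax<float>()(
    output_size, input_size, top_diff_ptr, mask_ptr, top_offset,
    bottom_offset, bottom_diff_ptr, eigen_gpu_device);

The explicit `template struct` instantiations in DEFINE_GPU_KERNELS are what make such calls link, since the definitions live in this .cu.cc file rather than a header.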
 
-DEFINE_GPU_KERNELS(float)
-DEFINE_GPU_KERNELS(Eigen::half)
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
 
 #undef DEFINE_GPU_KERNELS
 
+}  // namespace functor
+
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA
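For illustration, and not part of the patch: TF_CALL_GPU_NUMBER_TYPES (defined in tensorflow/core/framework/register_types.h) applies its macro argument once per supported GPU numeric type, replacing the previous hand-written float and Eigen::half lines. A simplified stand-in shows the mechanism:

// Hypothetical, reduced version of the real type list.
#define CALL_FOR_GPU_TYPES(m) m(float) m(Eigen::half)
// CALL_FOR_GPU_TYPES(DEFINE_GPU_KERNELS) would expand to
//   DEFINE_GPU_KERNELS(float) DEFINE_GPU_KERNELS(Eigen::half)
// i.e. one block of explicit instantiations per type.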