Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.cu.cc  264
1 file changed, 192 insertions(+), 72 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 91b50b1e11..0c638ca233 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -199,32 +199,145 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
}
}
-#undef CUDA_1D_KERNEL_LOOP
-} // namespace
+// The parameters to the kernels in the gradient-of-gradient function are as
+// follows:
+// nthreads: the number of threads, which is equal to the output size. The
+// gradient of the MaxPooling gradient w.r.t. the output data has
+// dimensions of N*C*Hout*Wout.
+// bottom_data: the bottom data of N*H*W*C (or N*C*H*W) items.
+// output_data: the output data of N*Hout*Wout*C (or N*C*Hout*Wout) items.
+// height, width, pooled_height, pooled_width: the input and output sizes.
+// kernel_h, kernel_w: the kernel sizes.
+// stride_h, stride_w: the strides.
+// pad_t, pad_l: the padding values on the top and left side.
+// top_diff: the gradient of the gradient of the output data w.r.t. the
+// input data, of size N*H*W*C (or N*C*H*W).
+// bottom_diff: the gradient of the gradient w.r.t. output.
+template <typename dtype>
+__global__ void MaxPoolGradBackwardNoMaskNCHW(
+ const int nthreads, const dtype* bottom_data, const dtype* output_data,
+ const int pooled_height, const int pooled_width, const int channels,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ const dtype* top_diff, dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // First find out the index to the maximum, since we have no mask.
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ int hstart = ph * stride_h - pad_t;
+ int wstart = pw * stride_w - pad_l;
+ const int hend = min(hstart + kernel_h, height);
+ const int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ bool should_stop = false;
+ int maxidx = -1;
+ const dtype* bottom_data_n = bottom_data + n * channels * height * width;
+ // Propagate only the first value from top_diff corresponding to the maximum.
+ for (int h = hstart; h < hend && !should_stop; ++h) {
+ for (int w = wstart; w < wend && !should_stop; ++w) {
+ int idx = c * height * width + h * width + w;
+ if (output_data[index] == bottom_data_n[idx]) {
+ maxidx = idx;
+ should_stop = true;
+ }
+ }
+ }
+ // Set the bottom diff (no atomic is necessary). The index could still be
+ // uninitialized if all of the bottom_data values are NaN.
+ if (maxidx != -1) {
+ bottom_diff[index] = top_diff[n * channels * height * width + maxidx];
+ }
+ }
+}
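
For reference, the flat-index arithmetic in the NCHW kernel above can be mirrored on the host. The sketch below is illustrative only; the PoolWindow struct and the DecodeNCHW name are not part of this change.

#include <algorithm>

// Recovers (n, c, ph, pw) and the clamped pooling window from a flat
// N*C*Hout*Wout index, using the same arithmetic as the kernel.
struct PoolWindow { int n, c, hstart, hend, wstart, wend; };

inline PoolWindow DecodeNCHW(int index, int channels, int height, int width,
                             int pooled_height, int pooled_width, int kernel_h,
                             int kernel_w, int stride_h, int stride_w,
                             int pad_t, int pad_l) {
  const int pw = index % pooled_width;
  const int ph = (index / pooled_width) % pooled_height;
  const int c = (index / pooled_width / pooled_height) % channels;
  const int n = index / pooled_width / pooled_height / channels;
  int hstart = ph * stride_h - pad_t;
  int wstart = pw * stride_w - pad_l;
  const int hend = std::min(hstart + kernel_h, height);
  const int wend = std::min(wstart + kernel_w, width);
  hstart = std::max(hstart, 0);
  wstart = std::max(wstart, 0);
  return PoolWindow{n, c, hstart, hend, wstart, wend};
}
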
-bool MaxPoolForwardWithOptionalArgmax(
- const float* bottom_data, const int batch, const int height,
- const int width, const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h, const int kernel_w,
+template <typename dtype>
+__global__ void MaxPoolGradBackwardNoMaskNHWC(
+ const int nthreads, const dtype* bottom_data, const dtype* output_data,
+ const int pooled_height, const int pooled_width, const int channels,
+ const int height, const int width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
- float* top_data, int64* mask, const Eigen::GpuDevice& d) {
- const int kThreadsPerBlock = 1024;
- const int output_size = batch * channels * pooled_height * pooled_width;
+ const dtype* top_diff, dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // First find out the index to the maximum, since we have no mask.
+ int n = index;
+ int c = n % channels;
+ n /= channels;
+ int wstart = (n % pooled_width) * stride_w - pad_l;
+ n /= pooled_width;
+ int hstart = (n % pooled_height) * stride_h - pad_t;
+ n /= pooled_height;
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ bool should_stop = false;
+ int maxidx = -1;
+ const dtype* bottom_data_n = bottom_data + n * height * width * channels;
+ // Propagate only the first value from top_diff corresponding to the maximum.
+ for (int h = hstart; h < hend && !should_stop; ++h) {
+ for (int w = wstart; w < wend && !should_stop; ++w) {
+ int idx = (h * width + w) * channels + c;
+ if (output_data[index] == bottom_data_n[idx]) {
+ maxidx = idx;
+ should_stop = true;
+ }
+ }
+ }
+ // Set the bottom diff (no atomic is necessary). The index could still be
+ // uninitialized if all of the bottom_data values are NaN.
+ if (maxidx != -1) {
+ bottom_diff[index] = top_diff[n * height * width * channels + maxidx];
+ }
+ }
+}
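
End to end, the NHWC variant is a per-output gather from top_diff at the position of the window maximum. A rough single-image CPU reference is sketched below; it is illustrative only and the function name is hypothetical, not part of this change.

#include <algorithm>

// Single-image CPU reference for the NHWC kernel: for every pooled output
// element, find the first input position whose value equals the pooled
// output (the argmax, since no mask is stored) and copy the matching
// element of top_diff into bottom_diff. Unmatched outputs (all-NaN windows)
// are left untouched, as in the kernel.
void MaxPoolGradBackwardNoMaskNHWCRef(
    const float* bottom_data, const float* output_data, const float* top_diff,
    float* bottom_diff, int channels, int height, int width,
    int pooled_height, int pooled_width, int kernel_h, int kernel_w,
    int stride_h, int stride_w, int pad_t, int pad_l) {
  for (int ph = 0; ph < pooled_height; ++ph) {
    for (int pw = 0; pw < pooled_width; ++pw) {
      for (int c = 0; c < channels; ++c) {
        const int out_idx = (ph * pooled_width + pw) * channels + c;
        const int hend = std::min(ph * stride_h - pad_t + kernel_h, height);
        const int wend = std::min(pw * stride_w - pad_l + kernel_w, width);
        const int hstart = std::max(ph * stride_h - pad_t, 0);
        const int wstart = std::max(pw * stride_w - pad_l, 0);
        bool found = false;
        for (int h = hstart; h < hend && !found; ++h) {
          for (int w = wstart; w < wend && !found; ++w) {
            const int in_idx = (h * width + w) * channels + c;
            if (output_data[out_idx] == bottom_data[in_idx]) {
              bottom_diff[out_idx] = top_diff[in_idx];  // first match wins
              found = true;
            }
          }
        }
      }
    }
  }
}
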
- MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(
- output_size, bottom_data, height, width, channels, pooled_height,
- pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
- top_data, mask);
- return d.ok();
+// The parameters to the kernels in the gradient-of-gradient function are as
+// follows:
+// nthreads: the number of threads, which is equal to the output size. The
+// gradient of the MaxPooling gradient w.r.t. the output data has
+// dimensions of N*C*Hout*Wout.
+// top_diff: the gradient of the gradient of the output data w.r.t. the
+// input data, of size N*H*W*C (or N*C*H*W). As we have stored the
+// flattened index of the input entries, the backward function is
+// agnostic of the input storage order.
+// mask: the output mask of the same size as top_data. It is stored in
+// int form, keeping track of the flattened index of the input item that
+// produces the max output.
+// top_offset: the pre-computed per-image offset of the maxpool input
+// gradient. This is equal to H*W*C. We choose to pre-compute this so we
+// do not need to compute it every time inside the kernel.
+// bottom_offset: the pre-computed per-image offset of the maxpool output.
+// This is equal to Hout*Wout*C.
+// bottom_diff: the gradient of the gradient w.r.t. output.
+// Each element of bottom_diff is written exactly once, so no atomic
+// operations are needed and bottom_diff does not have to be zero-filled
+// before the kernel is run.
+template <typename dtype>
+__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
+ const int64* mask, const int top_offset,
+ const int bottom_offset,
+ dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int image_id = (index / bottom_offset);
+ bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
+ }
}
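
The mask-based kernel above is a pure gather with no accumulation. A minimal CPU equivalent is sketched below, for illustration only, with long long standing in for TF's int64.

// bottom_diff has one element per pooled output (bottom_offset = Hout*Wout*C
// per image); mask[index] holds the flattened per-image offset of the input
// element that produced that output, so each output simply copies one element
// of top_diff (top_offset = H*W*C per image).
void MaxPoolGradBackwardRef(int nthreads, const float* top_diff,
                            const long long* mask, int top_offset,
                            int bottom_offset, float* bottom_diff) {
  for (int index = 0; index < nthreads; ++index) {
    const int image_id = index / bottom_offset;
    bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
  }
}
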
-bool MaxPoolForwardWithOptionalArgmax(
- const Eigen::half* bottom_data, const int batch, const int height,
- const int width, const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h, const int kernel_w,
- const int stride_h, const int stride_w, const int pad_t, const int pad_l,
- Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) {
+#undef CUDA_1D_KERNEL_LOOP
+} // namespace
+
+namespace functor {
+
+template <typename T>
+bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
+ const T* bottom_data, const int batch, const int height, const int width,
+ const int channels, const int pooled_height, const int pooled_width,
+ const int kernel_h, const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l, T* top_data,
+ int64* mask, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
@@ -236,14 +349,13 @@ bool MaxPoolForwardWithOptionalArgmax(
return d.ok();
}
-bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
- const int height, const int width,
- const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h,
- const int kernel_w, const int stride_h,
- const int stride_w, const int pad_t, const int pad_l,
- const float* top_diff, float* bottom_diff,
- const Eigen::GpuDevice& d) {
+template <typename T>
+bool MaxPoolBackwardNoMask<T>::operator()(
+ const T* bottom_data, const int batch, const int height, const int width,
+ const int channels, const int pooled_height, const int pooled_width,
+ const int kernel_h, const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
+ T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int bottom_size = batch * channels * height * width;
const int top_size = batch * channels * pooled_height * pooled_width;
@@ -260,34 +372,11 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
return d.ok();
}
-bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
- const int height, const int width,
- const int channels, const int pooled_height,
- const int pooled_width, const int kernel_h,
- const int kernel_w, const int stride_h,
- const int stride_w, const int pad_t, const int pad_l,
- const Eigen::half* top_diff, Eigen::half* bottom_diff,
- const Eigen::GpuDevice& d) {
- const int kThreadsPerBlock = 1024;
- const int bottom_size = batch * channels * height * width;
- const int top_size = batch * channels * pooled_height * pooled_width;
-
- SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
-
- MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
- kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(
- top_size, bottom_data, height, width, channels, pooled_height,
- pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
- top_diff, bottom_diff);
- return d.ok();
-}
-
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
- const float* top_diff, const int64* mask,
- const int top_offset, const int bottom_offset,
- float* bottom_diff, const Eigen::GpuDevice& d) {
+template <typename T>
+bool MaxPoolBackwardWithArgmax<T>::operator()(
+ const int output_size, const int input_size, const T* top_diff,
+ const int64* mask, const int top_offset, const int bottom_offset,
+ T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
@@ -297,30 +386,61 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
return d.ok();
}
-bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
- const Eigen::half* top_diff, const int64* mask,
- const int top_offset, const int bottom_offset,
- Eigen::half* bottom_diff,
- const Eigen::GpuDevice& d) {
- const int kThreadsPerBlock = 1024;
- SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
- MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(
- output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
+template <typename T>
+bool MaxPoolGradBackwardNoMask<T>::operator()(
+ TensorFormat data_format, const T* bottom_data, const T* output_data,
+ const int batch, const int pooled_height, const int pooled_width,
+ const int channels, const int height, const int width, const int kernel_h,
+ const int kernel_w, const int stride_h, const int stride_w, const int pad_t,
+ const int pad_l, const T* top_diff, T* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int num_kernels = batch * channels * pooled_height * pooled_width;
+ CudaLaunchConfig config = GetCudaLaunchConfig(num_kernels, d);
+
+ if (data_format == FORMAT_NHWC) {
+ MaxPoolGradBackwardNoMaskNHWC<<<config.block_count, config.thread_per_block,
+ 0, d.stream()>>>(
+ num_kernels, bottom_data, output_data, pooled_height, pooled_width,
+ channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_t,
+ pad_l, top_diff, bottom_diff);
+ } else {
+ MaxPoolGradBackwardNoMaskNCHW<<<config.block_count, config.thread_per_block,
+ 0, d.stream()>>>(
+ num_kernels, bottom_data, output_data, pooled_height, pooled_width,
+ channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_t,
+ pad_l, top_diff, bottom_diff);
+ }
+ return d.ok();
+}
+
+template <typename T>
+bool MaxPoolGradBackwardWithArgmax<T>::operator()(
+ const int output_size, const int input_size, const T* top_diff,
+ const int64* mask, const int top_offset, const int bottom_offset,
+ T* bottom_diff, const Eigen::GpuDevice& d) {
+ CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d);
+ MaxPoolGradBackward<<<config.block_count, config.thread_per_block, 0,
+ d.stream()>>>(output_size, top_diff, mask, top_offset,
+ bottom_offset, bottom_diff);
return d.ok();
}
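
A call site for the templated functors would look roughly like the sketch below; the OpKernelContext variable and the pointer names are assumptions for illustration, not code from this change.

// Hypothetical caller inside a GPU op implementation: forwards the incoming
// gradient through the stored argmax mask on the op's CUDA stream.
const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
functor::MaxPoolGradBackwardWithArgmax<float>()(
    output_size, input_size, top_diff_ptr, mask_ptr,
    /*top_offset=*/height * width * channels,
    /*bottom_offset=*/pooled_height * pooled_width * channels,
    bottom_diff_ptr, device);
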
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::SpatialMaxPooling<GPUDevice, T>;
+#define DEFINE_GPU_KERNELS(T) \
+ template struct SpatialMaxPooling<GPUDevice, T>; \
+ template struct MaxPoolForwardWithOptionalArgmax<T>; \
+ template struct MaxPoolBackwardWithArgmax<T>; \
+ template struct MaxPoolBackwardNoMask<T>; \
+ template struct MaxPoolGradBackwardWithArgmax<T>; \
+ template struct MaxPoolGradBackwardNoMask<T>;
-DEFINE_GPU_KERNELS(float)
-DEFINE_GPU_KERNELS(Eigen::half)
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
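
For a single element type the macro expands to explicit template instantiations; for float the result is equivalent to the lines below (shown for illustration). TF_CALL_GPU_NUMBER_TYPES applies DEFINE_GPU_KERNELS once per supported GPU floating-point type, replacing the hand-written float and Eigen::half invocations removed above.

// Equivalent explicit instantiations produced by DEFINE_GPU_KERNELS(float):
template struct SpatialMaxPooling<GPUDevice, float>;
template struct MaxPoolForwardWithOptionalArgmax<float>;
template struct MaxPoolBackwardWithArgmax<float>;
template struct MaxPoolBackwardNoMask<float>;
template struct MaxPoolGradBackwardWithArgmax<float>;
template struct MaxPoolGradBackwardNoMask<float>;
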
+} // namespace functor
+
} // end namespace tensorflow
#endif // GOOGLE_CUDA