diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-06-03 06:33:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-03 07:41:08 -0700 |
commit | 9bedadceab3e126684494e6e6a8103ccab9d90c7 (patch) | |
tree | 9cedee72446581781cd49e81f7e26f1f90b6803c /tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | |
parent | 536e7caf86f3dd70518f6d45c2ab7ed19747c0c1 (diff) |
Enable fp16 for most of the pooling ops (MaxPool, AvgPool, associated
gradients, some variants etc.).
Change: 123967117
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 72 |
1 files changed, 64 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index 1bdca42f4e..91b50b1e11 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -110,7 +110,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - dtype maxval = -FLT_MAX; + dtype maxval = Eigen::NumTraits<dtype>::lowest(); int maxidx = -1; const dtype* bottom_data_n = bottom_data + n * height * width * channels; for (int h = hstart; h < hend; ++h) { @@ -149,7 +149,7 @@ __global__ void MaxPoolBackwardNoMaskNHWC( int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - dtype maxval = -FLT_MAX; + dtype maxval = Eigen::NumTraits<dtype>::lowest(); int maxidx = -1; const dtype* bottom_data_n = bottom_data + n * height * width * channels; for (int h = hstart; h < hend; ++h) { @@ -165,8 +165,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC( // Atomically accumulate the bottom diff. The index could still be // uninitialized, if all the bottom_data are NaN. if (maxidx != -1) { - atomicAdd(bottom_diff + n * height * width * channels + maxidx, - top_diff[index]); + CudaAtomicAdd(bottom_diff + n * height * width * channels + maxidx, + top_diff[index]); } } } @@ -185,8 +185,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC( // bottom_offset: the pre-computed per-image offset of the maxpool input. // This is equal to H*W*C. // bottom_diff: the gradient with respect to the input. -// This function relies on atomicAdd to avoid race conditions. Also, before the -// kernel is run, you will need to make sure that bottom_diff is filled with +// This function relies on CudaAtomicAdd to avoid race conditions. Also, before +// the kernel is run, you will need to make sure that bottom_diff is filled with // zero first. template <typename dtype> __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, @@ -194,8 +194,8 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, const int bottom_offset, dtype* bottom_diff) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int image_id = (index / top_offset); - atomicAdd(bottom_diff + image_id * bottom_offset + mask[index], - top_diff[index]); + CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index], + top_diff[index]); } } @@ -219,6 +219,23 @@ bool MaxPoolForwardWithOptionalArgmax( return d.ok(); } +bool MaxPoolForwardWithOptionalArgmax( + const Eigen::half* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int output_size = batch * channels * pooled_height * pooled_width; + + MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_data, mask); + return d.ok(); +} + bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, const int height, const int width, const int channels, const int pooled_height, @@ -243,6 +260,30 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, return d.ok(); } +bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const Eigen::half* top_diff, Eigen::half* bottom_diff, + const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int bottom_size = batch * channels * height * width; + const int top_size = batch * channels * pooled_height * pooled_width; + + SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff); + + MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) / + kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + top_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_diff, bottom_diff); + return d.ok(); +} + bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, const float* top_diff, const int64* mask, const int top_offset, const int bottom_offset, @@ -256,12 +297,27 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, return d.ok(); } +bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, + const Eigen::half* top_diff, const int64* mask, + const int top_offset, const int bottom_offset, + Eigen::half* bottom_diff, + const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); + MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); + return d.ok(); +} + typedef Eigen::GpuDevice GPUDevice; #define DEFINE_GPU_KERNELS(T) \ template struct functor::SpatialMaxPooling<GPUDevice, T>; DEFINE_GPU_KERNELS(float) +DEFINE_GPU_KERNELS(Eigen::half) #undef DEFINE_GPU_KERNELS |