author:    Benoit Steiner <benoit.steiner.goog@gmail.com>  2016-06-03 06:33:36 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>  2016-06-03 07:41:08 -0700
commit:    9bedadceab3e126684494e6e6a8103ccab9d90c7 (patch)
tree:      9cedee72446581781cd49e81f7e26f1f90b6803c /tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
parent:    536e7caf86f3dd70518f6d45c2ab7ed19747c0c1 (diff)
Enable fp16 for most of the pooling ops (MaxPool, AvgPool, associated
gradients, and some variants). Change: 123967117
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r-- tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 72
1 file changed, 64 insertions(+), 8 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 1bdca42f4e..91b50b1e11 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -110,7 +110,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
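The -FLT_MAX to Eigen::NumTraits<dtype>::lowest() change is what makes this kernel safe to instantiate for Eigen::half: -FLT_MAX is far outside the half range, whose most negative finite value is -65504, so the old initializer would not be a valid fp16 running maximum. A minimal host-side sketch of the difference, assuming an Eigen version recent enough to ship Eigen::half via Eigen/Core (older trees pulled it in through the unsupported Tensor module):

#include <cstdio>
#include <Eigen/Core>  // assumption: Eigen::half is reachable from here

int main() {
  // float: lowest() is -FLT_MAX, roughly -3.4e+38.
  std::printf("float lowest: %e\n", Eigen::NumTraits<float>::lowest());
  // Eigen::half: lowest() is -65504, the most negative finite fp16 value.
  std::printf("half lowest:  %f\n",
              static_cast<float>(Eigen::NumTraits<Eigen::half>::lowest()));
}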
@@ -149,7 +149,7 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
@@ -165,8 +165,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// Atomically accumulate the bottom diff. The index could still be
// uninitialized, if all the bottom_data are NaN.
if (maxidx != -1) {
- atomicAdd(bottom_diff + n * height * width * channels + maxidx,
- top_diff[index]);
+ CudaAtomicAdd(bottom_diff + n * height * width * channels + maxidx,
+ top_diff[index]);
}
}
}
@@ -185,8 +185,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool input.
// This is equal to H*W*C.
// bottom_diff: the gradient with respect to the input.
-// This function relies on atomicAdd to avoid race conditions. Also, before the
-// kernel is run, you will need to make sure that bottom_diff is filled with
+// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
+// the kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
@@ -194,8 +194,8 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
const int bottom_offset, dtype* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int image_id = (index / top_offset);
- atomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
- top_diff[index]);
+ CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
+ top_diff[index]);
}
}
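These call sites move from raw atomicAdd to the CudaAtomicAdd wrapper (from TensorFlow's CUDA kernel helpers) because plain atomicAdd has no fp16 overload on the architectures targeted here; the wrapper supplies overloads for the types these kernels are now instantiated with, falling back to a compare-and-swap loop where no hardware atomic exists. A sketch of that fallback pattern, shown for double (the canonical example from the CUDA programming guide), not the actual TensorFlow helper; the fp16 variant is the same loop plus 16-bit packing inside an aligned 32-bit word:

__device__ double AtomicAddViaCAS(double* address, double val) {
  // Reinterpret the target as the 64-bit integer type atomicCAS accepts.
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    // Try to swap in (assumed + val); retry if another thread won the race.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}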
@@ -219,6 +219,23 @@ bool MaxPoolForwardWithOptionalArgmax(
return d.ok();
}
+bool MaxPoolForwardWithOptionalArgmax(
+ const Eigen::half* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int output_size = batch * channels * pooled_height * pooled_width;
+
+ MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_data, mask);
+ return d.ok();
+}
+
bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
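The new Eigen::half overload mirrors the float version line for line. The grid size (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock is a ceiling division that guarantees at least one thread per output element, and CUDA_1D_KERNEL_LOOP inside the kernel is a grid-stride loop, so the launch stays correct even when the grid is capped. A sketch of that idiom (the real macro lives in TensorFlow's CUDA kernel helpers, so treat this as illustrative):

// Grid-stride loop: each thread starts at its global index and walks the
// index space in steps of the total thread count.
#define CUDA_1D_KERNEL_LOOP_SKETCH(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)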
@@ -243,6 +260,30 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
return d.ok();
}
+bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const Eigen::half* top_diff, Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int bottom_size = batch * channels * height * width;
+ const int top_size = batch * channels * pooled_height * pooled_width;
+
+ SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
+
+ MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
+ kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ top_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_diff, bottom_diff);
+ return d.ok();
+}
+
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const float* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
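The fp16 backward paths keep the same two-launch structure as their float counterparts: SetZero first clears bottom_diff, then the scatter kernel accumulates gradients into it with atomic adds. Both kernels are enqueued on the same d.stream(), so the zero-fill is ordered before the accumulation. A sketch of what a SetZero kernel like the one launched above looks like (the actual kernel is defined earlier in this file); note that dtype(0) works uniformly for float and Eigen::half:

template <typename dtype>
__global__ void SetZeroSketch(const int nthreads, dtype* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = dtype(0);  // zero-initialize before atomic accumulation
  }
}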
@@ -256,12 +297,27 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
return d.ok();
}
+bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
+ const Eigen::half* top_diff, const int64* mask,
+ const int top_offset, const int bottom_offset,
+ Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
+ MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
+ return d.ok();
+}
+
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::SpatialMaxPooling<GPUDevice, T>;
DEFINE_GPU_KERNELS(float)
+DEFINE_GPU_KERNELS(Eigen::half)
#undef DEFINE_GPU_KERNELS
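The trailing instantiation block is what actually emits device code for the new type: the added DEFINE_GPU_KERNELS(Eigen::half) line expands to an explicit template instantiation, so the fp16 specialization of the pooling functor is compiled into this translation unit and available at link time when the op is registered for half inputs:

// Expansion of the new DEFINE_GPU_KERNELS(Eigen::half) line:
template struct functor::SpatialMaxPooling<GPUDevice, Eigen::half>;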