Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 40
1 file changed, 29 insertions(+), 11 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 26f5274804..d96b844383 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -29,6 +29,15 @@ limitations under the License.
namespace tensorflow {
namespace {
+template <bool propagate_nans, typename dtype>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
+ if (propagate_nans) {
+ return !(a <= b);
+ } else {
+ return a > b;
+ }
+}
+
// This is Yangqing's custom kernel for the maxpooling operation. There are
// three functions: MaxPoolForwardNCHW and MaxPoolForwardNHWC are the two
// forward functions, dealing with the forward case. MaxPoolBackward is the
@@ -51,7 +60,7 @@ namespace {
// const int output_size = batch * channels * pooled_height * pooled_width;
// MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
// kThreadsPerBlock, 0, cuda_stream>>>(...);
-template <typename dtype>
+template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
const int channels, const int height,
const int width, const int pooled_height,
@@ -77,7 +86,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int idx = c * height * width + h * width + w;
- if (bottom_data_n[idx] > maxval) {
+ if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
maxidx = idx;
maxval = bottom_data_n[idx];
}
@@ -126,7 +135,7 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
}
}
-template <typename dtype>
+template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
const int height, const int width,
const int channels, const int pooled_height,
@@ -153,7 +162,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int idx = (h * width + w) * channels + c;
- if (bottom_data_n[idx] > maxval) {
+ if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
maxidx = idx;
maxval = bottom_data_n[idx];
}
@@ -390,15 +399,24 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
const int channels, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_t, const int pad_l, T* top_data,
- int64* mask, const Eigen::GpuDevice& d) {
+ int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
-
- MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
- kThreadsPerBlock, 0, d.stream()>>>(
- output_size, bottom_data, height, width, channels, pooled_height,
- pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
- top_data, mask);
+ if (propagate_nans) {
+ MaxPoolForwardNHWC<true>
+ <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>
+ (output_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_data, mask);
+ } else {
+ MaxPoolForwardNHWC<false>
+ <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>
+ (output_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_data, mask);
+ }
return d.ok();
}
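
Note on the comparison: the reason `!(a <= b)` propagates NaNs is that every ordered comparison involving a NaN evaluates to false, so negating `a <= b` yields true whenever `a` is NaN, letting the NaN replace the running max. Plain `a > b` is false for NaN, so NaNs are silently skipped. A minimal host-side sketch (not part of the patch; `is_greater_than` is a hypothetical stand-in for the kernel helper) illustrating the two branches of `IsGreaterThan`:

#include <cmath>
#include <cstdio>
#include <limits>

// Mirrors IsGreaterThan from the patch; propagate_nans is a compile-time flag.
template <bool propagate_nans, typename T>
bool is_greater_than(T a, T b) {
  if (propagate_nans) {
    return !(a <= b);  // true when a is NaN, so NaN wins the running max
  } else {
    return a > b;      // false when a is NaN, so NaN is ignored
  }
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::printf("%d\n", is_greater_than<true>(nan, 1.0f));   // 1: NaN propagates
  std::printf("%d\n", is_greater_than<false>(nan, 1.0f));  // 0: NaN is skipped
}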
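
Note on the dispatch: the host-side wrapper converts the runtime `propagate_nans` flag into a compile-time template argument by branching once before the launch, so the `if (propagate_nans)` inside `IsGreaterThan` is resolved at compile time for each instantiation rather than evaluated per element. A hedged sketch of the same pattern with a hypothetical kernel `FillFlag` (not part of the patch):

#include <cuda_runtime.h>

// Hypothetical kernel: the flag is a template parameter, so the branch
// below disappears at compile time in each instantiation.
template <bool propagate_nans>
__global__ void FillFlag(int* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = propagate_nans ? 1 : 0;
  }
}

// Host-side dispatch: branch once on the runtime flag, then launch the
// matching instantiation -- the same shape as MaxPoolForwardWithOptionalArgmax.
void LaunchFillFlag(int* out, int n, bool propagate_nans, cudaStream_t stream) {
  const int kThreadsPerBlock = 256;
  const int blocks = (n + kThreadsPerBlock - 1) / kThreadsPerBlock;
  if (propagate_nans) {
    FillFlag<true><<<blocks, kThreadsPerBlock, 0, stream>>>(out, n);
  } else {
    FillFlag<false><<<blocks, kThreadsPerBlock, 0, stream>>>(out, n);
  }
}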