Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 40
1 file changed, 29 insertions(+), 11 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 26f5274804..d96b844383 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -29,6 +29,15 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+template <bool propagate_nans, typename dtype>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
+  if (propagate_nans) {
+    return !(a <= b);
+  } else {
+    return a > b;
+  }
+}
+
 // This is Yangqing's custom kernel for the maxpooling operation. There are
 // three functions: MaxPoolForwardNCHW and MaxPoolForwardNHWC are the two
 // forward functions, dealing with the forward case. MaxPoolBackward is the
@@ -51,7 +60,7 @@ namespace {
 //     const int output_size = batch * channels * pooled_height * pooled_width;
 //     MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
 //                          kThreadsPerBlock, 0, cuda_stream>>>(...);
-template <typename dtype>
+template <bool propagate_nans, typename dtype>
 __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
                                    const int channels, const int height,
                                    const int width, const int pooled_height,
@@ -77,7 +86,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int idx = c * height * width + h * width + w;
-        if (bottom_data_n[idx] > maxval) {
+        if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
           maxidx = idx;
           maxval = bottom_data_n[idx];
         }
@@ -126,7 +135,7 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
   }
 }
 
-template <typename dtype>
+template <bool propagate_nans, typename dtype>
 __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
                                    const int height, const int width,
                                    const int channels, const int pooled_height,
@@ -153,7 +162,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int idx = (h * width + w) * channels + c;
-        if (bottom_data_n[idx] > maxval) {
+        if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
           maxidx = idx;
           maxval = bottom_data_n[idx];
         }
@@ -390,15 +399,24 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
     const int channels, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
     const int stride_w, const int pad_t, const int pad_l, T* top_data,
-    int64* mask, const Eigen::GpuDevice& d) {
+    int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) {
   const int kThreadsPerBlock = 1024;
   const int output_size = batch * channels * pooled_height * pooled_width;
-
-  MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-                       kThreadsPerBlock, 0, d.stream()>>>(
-      output_size, bottom_data, height, width, channels, pooled_height,
-      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-      top_data, mask);
+  if (propagate_nans) {
+    MaxPoolForwardNHWC<true>
+        <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+           kThreadsPerBlock, 0, d.stream()>>>
+        (output_size, bottom_data, height, width, channels, pooled_height,
+         pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+         top_data, mask);
+  } else {
+    MaxPoolForwardNHWC<false>
+        <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+           kThreadsPerBlock, 0, d.stream()>>>
+        (output_size, bottom_data, height, width, channels, pooled_height,
+         pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+         top_data, mask);
+  }
   return d.ok();
 }
 
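Note on the comparison helper: the two branches of IsGreaterThan differ only in how they treat NaN. Under IEEE 754, every ordered comparison involving NaN evaluates to false, so `a > b` never lets a NaN candidate replace the running max, while `!(a <= b)` is true whenever `a` is NaN. The standalone host-side sketch below (not part of the commit; main() and its output are illustrative) contrasts the two predicates, minus the CUDA qualifiers:

// Host-side sketch (not from the commit) contrasting the two predicates.
// Any ordered comparison with NaN is false, so `a > b` rejects a NaN
// candidate while `!(a <= b)` accepts it.
#include <cmath>
#include <cstdio>

template <bool propagate_nans, typename dtype>
bool IsGreaterThan(dtype a, dtype b) {  // same logic as in the diff
  if (propagate_nans) {
    return !(a <= b);
  } else {
    return a > b;
  }
}

int main() {
  const float nan = std::nanf("");
  // NaN candidate vs. a running max of 1.0f:
  std::printf("propagate: %d\n", IsGreaterThan<true>(nan, 1.0f));   // 1: NaN wins
  std::printf("suppress:  %d\n", IsGreaterThan<false>(nan, 1.0f));  // 0: NaN ignored
  // Ordinary values behave identically under both predicates:
  std::printf("2 vs 1:    %d %d\n", IsGreaterThan<true>(2.0f, 1.0f),
              IsGreaterThan<false>(2.0f, 1.0f));                    // 1 1
  return 0;
}

The launcher follows the same idea at the kernel level: it branches on the runtime flag `propagate_nans` once on the host and instantiates MaxPoolForwardNHWC with a compile-time bool, so the NaN policy is a constant inside the per-element pooling loop rather than a runtime branch. The cost is a second kernel instantiation per data type.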