Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.cu.cc | 40
1 file changed, 11 insertions(+), 29 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index d96b844383..26f5274804 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -29,15 +29,6 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-template <bool propagate_nans, typename dtype>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
-  if (propagate_nans) {
-    return !(a <= b);
-  } else {
-    return a > b;
-  }
-}
-
 // This is Yangqing's custom kernel for the maxpooling operation. There are
 // three functions: MaxPoolForwardNCHW and MaxPoolForwardNHWC are the two
 // forward functions, dealing with the forward case. MaxPoolBackward is the
@@ -60,7 +51,7 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
 // const int output_size = batch * channels * pooled_height * pooled_width;
 // MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
 //                      kThreadsPerBlock, 0, cuda_stream>>>(...);
-template <bool propagate_nans, typename dtype>
+template <typename dtype>
 __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
                                    const int channels, const int height,
                                    const int width, const int pooled_height,
@@ -86,7 +77,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int idx = c * height * width + h * width + w;
-        if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
+        if (bottom_data_n[idx] > maxval) {
           maxidx = idx;
           maxval = bottom_data_n[idx];
         }
@@ -135,7 +126,7 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
   }
 }
 
-template <bool propagate_nans, typename dtype>
+template <typename dtype>
 __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
                                    const int height, const int width,
                                    const int channels, const int pooled_height,
@@ -162,7 +153,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
     for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
         int idx = (h * width + w) * channels + c;
-        if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
+        if (bottom_data_n[idx] > maxval) {
           maxidx = idx;
           maxval = bottom_data_n[idx];
         }
@@ -399,24 +390,15 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
     const int channels, const int pooled_height, const int pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
     const int stride_w, const int pad_t, const int pad_l, T* top_data,
-    int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) {
+    int64* mask, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   const int output_size = batch * channels * pooled_height * pooled_width;
-  if (propagate_nans) {
-    MaxPoolForwardNHWC<true>
-        <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-           kThreadsPerBlock, 0, d.stream()>>>
-        (output_size, bottom_data, height, width, channels, pooled_height,
-         pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-         top_data, mask);
-  } else {
-    MaxPoolForwardNHWC<false>
-        <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-           kThreadsPerBlock, 0, d.stream()>>>
-        (output_size, bottom_data, height, width, channels, pooled_height,
-         pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
-         top_data, mask);
-  }
+
+  MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+                       kThreadsPerBlock, 0, d.stream()>>>(
+      output_size, bottom_data, height, width, channels, pooled_height,
+      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+      top_data, mask);
   return d.ok();
 }
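For context on what this diff changes semantically: the deleted IsGreaterThan helper controlled how NaNs flow through max pooling. Every ordered comparison involving a NaN is false, so plain `a > b` (the behavior this commit restores) never lets a NaN candidate replace the running max, while `!(a <= b)` is true whenever either operand is NaN and therefore admits NaN into the result. Below is a minimal host-side C++ sketch of that difference; it is not part of the commit, and the WindowMax helper is hypothetical, standing in for the per-window scan the kernels perform.

// Sketch only: contrasts the two comparison predicates from the diff.
// WindowMax is a hypothetical stand-in for the kernels' per-window loop.
#include <cstdio>
#include <initializer_list>
#include <limits>

template <bool propagate_nans, typename dtype>
bool IsGreaterThan(dtype a, dtype b) {
  // !(a <= b) is true if a > b OR if either operand is NaN;
  // a > b is false whenever either operand is NaN.
  return propagate_nans ? !(a <= b) : a > b;
}

template <bool propagate_nans>
float WindowMax(std::initializer_list<float> window) {
  float maxval = std::numeric_limits<float>::lowest();
  for (float v : window) {
    if (IsGreaterThan<propagate_nans>(v, maxval)) maxval = v;
  }
  return maxval;
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Plain `>`: the NaN candidate never wins, so it is silently dropped.
  std::printf("%f\n", WindowMax<false>({1.0f, 3.0f, nan}));  // 3.000000
  // !(a <= b): the NaN candidate replaces the running max.
  std::printf("%f\n", WindowMax<true>({1.0f, 3.0f, nan}));   // nan
}

Note that `!(a <= b)` is also true when the running max is already NaN, so a later finite element can overwrite a NaN; propagation therefore depends on scan order. Dropping the propagate_nans template parameter also removes the runtime branch over two kernel instantiations, which is why the launch site above collapses to a single MaxPoolForwardNHWC call.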