| author | Vijay Vasudevan <vrv@google.com> | 2016-01-07 13:26:02 -0800 |
|---|---|---|
| committer | Vijay Vasudevan <vrv@google.com> | 2016-01-07 13:26:02 -0800 |
| commit | 6cc392e3b0989744c7b16248b19b48bcbe54fa6a (patch) | |
| tree | 1537305b998b6dfc6ed700bb0060947549e94689 /tensorflow/core/kernels/avgpooling_op_gpu.cu.cc | |
| parent | 3ffa307e49e5b150934a71386194d7ed621e3e98 (diff) | |
some linting fixes to changes brought in from the public.
Change: 111621725
Diffstat (limited to 'tensorflow/core/kernels/avgpooling_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/avgpooling_op_gpu.cu.cc | 20 |
1 file changed, 7 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
index 15aa7bad41..fa54f6c42d 100644
--- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
@@ -24,6 +24,7 @@ limitations under the License.

 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"

 namespace tensorflow {

@@ -36,12 +37,6 @@ DEFINE_GPU_KERNELS(float)

 #undef DEFINE_GPU_KERNELS

-#define CUDA_1D_KERNEL_LOOP(i, n)                              \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
-       i += blockDim.x * gridDim.x)
-
-static const int CAFFE_CUDA_NUM_THREADS = 1024;
-
 template <typename dtype>
 __global__ void AvePoolBackwardNHWC(const int nthreads,
                                     const dtype* const top_diff, const int num,
@@ -93,13 +88,12 @@ bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
                             const int pad_l, T* const bottom_diff,
                             const GPUDevice& d) {
   int x_size = num * height * width * channels;
-  int thread_per_block =
-      std::min(CAFFE_CUDA_NUM_THREADS, d.maxCudaThreadsPerMultiProcessor());
-  int block_count = (x_size + thread_per_block - 1) / thread_per_block;
-  AvePoolBackwardNHWC<T><<<block_count, thread_per_block, 0, d.stream()>>>(
-      x_size, top_diff, num, height, width, channels, pooled_height,
-      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_t,
-      bottom_diff);
+  CudaLaunchConfig config = GetCudaLaunchConfig(x_size, d);
+  AvePoolBackwardNHWC<
+      T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+      config.virtual_thread_count, top_diff, num, height, width, channels,
+      pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w,
+      pad_t, pad_t, bottom_diff);
   return d.ok();
 }
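The change above replaces a hand-rolled block/thread calculation (the removed CAFFE_CUDA_NUM_THREADS constant) with the shared GetCudaLaunchConfig helper from cuda_kernel_helper.h, which also supplies the CUDA_1D_KERNEL_LOOP grid-stride macro deleted here. Below is a minimal, self-contained sketch of that launch-configuration idiom; the LaunchConfig struct, MakeLaunchConfig function, and ScaleKernel are illustrative stand-ins written for this note, not TensorFlow's actual helper definitions.

```cuda
// Illustrative sketch only: LaunchConfig/MakeLaunchConfig mirror the fields
// used in the diff (block_count, thread_per_block, virtual_thread_count) but
// are hypothetical, not the real cuda_kernel_helper.h definitions.
#include <algorithm>
#include <cuda_runtime.h>

struct LaunchConfig {
  int virtual_thread_count;  // total number of logical work items
  int thread_per_block;      // threads per CUDA block
  int block_count;           // blocks in the grid
};

// Pick a block size capped at `max_threads` and enough blocks to cover
// `work_element_count` items (the role GetCudaLaunchConfig plays above).
LaunchConfig MakeLaunchConfig(int work_element_count, int max_threads = 1024) {
  LaunchConfig cfg;
  cfg.virtual_thread_count = work_element_count;
  cfg.thread_per_block = std::min(max_threads, std::max(work_element_count, 1));
  cfg.block_count =
      (work_element_count + cfg.thread_per_block - 1) / cfg.thread_per_block;
  return cfg;
}

// Grid-stride loop: each thread processes i, i + stride, i + 2*stride, ...,
// so any grid size covers all n elements. This is the pattern the removed
// CUDA_1D_KERNEL_LOOP macro expanded to.
__global__ void ScaleKernel(int n, const float* in, float scale, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = in[i] * scale;
  }
}

// Example launch: the kernel receives virtual_thread_count as its element
// count, just as AvePoolBackwardNHWC receives config.virtual_thread_count
// in the diff above.
void ScaleOnDevice(int n, const float* d_in, float scale, float* d_out) {
  LaunchConfig cfg = MakeLaunchConfig(n);
  ScaleKernel<<<cfg.block_count, cfg.thread_per_block>>>(
      cfg.virtual_thread_count, d_in, scale, d_out);
}
```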