author      2016-03-10 13:23:33 -0800
committer   2016-03-10 14:41:42 -0800
commit      7cc6ba76176b84c828dd6b40ec5d2bc0d481f46a (patch)
tree        4a30d5ed11487820cf8492d1da6d8530f02049dd /tensorflow/core/kernels/bias_op_gpu.cu.cc
parent      fd839dc6e22fd5e4766db237808828a9ff01a19c (diff)
Improve BiasGrad for NCHW by using less shared memory and improving memory
efficiency.
With GoogleNet V1, time spent in BiasGrad (ms):

                Before    After    Improvement
GoogleNet V1     19.70    13.14         49.93%
Change: 116901889
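(The improvement figure above appears to be the speedup expressed relative to the new time rather than a fraction of the old time: 19.70 / 13.14 − 1 ≈ 0.499, i.e. roughly 49.9%; in absolute terms the kernel time drops by about a third.)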
Diffstat (limited to 'tensorflow/core/kernels/bias_op_gpu.cu.cc')
-rw-r--r--   tensorflow/core/kernels/bias_op_gpu.cu.cc   85
1 file changed, 50 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 5c90c3715f..79344e7975 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -118,31 +118,51 @@ __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
 }
 
 template <typename T>
-__global__ void BiasGradNCHW_SharedAtomics(int32 nthreads,
-                                           const T* output_backprop,
-                                           T* bias_backprop, int32 bias_size,
-                                           int32 image_size,
-                                           int32 shared_replicas) {
-  T* s_data = reinterpret_cast<T*>(s_buf);
-  int32 s_data_size = bias_size * shared_replicas;
+__global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
+                                           T* bias_backprop, int32 batch,
+                                           int32 bias_size, int32 image_size,
+                                           int group_size) {
+  // Initialize the shared memory.
+  __shared__ T s_data[32];
+  int32 s_data_size = sizeof(s_data) / sizeof(T);
   for (int32 index = threadIdx.x; index < s_data_size; index += blockDim.x) {
     s_data[index] = 0;
   }
   __syncthreads();
 
-  for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int32 index2 = index / image_size;
-    int32 bias_slot_index = index2 % bias_size;
-    int32 bias_slot_offset = index % shared_replicas;
-    int32 bias_offset = bias_slot_index * shared_replicas + bias_slot_offset;
-    CudaAtomicAdd(s_data + bias_offset, ldg(output_backprop + index));
+  // Accumulate all the values within this thread. They all have the same bias
+  // index.
+  int32 bias_index = blockIdx.x % bias_size;
+  int32 group_index = blockIdx.x / bias_size;
+  int32 total_count = batch * image_size;
+  T sum = 0;
+  for (int32 index = group_index * blockDim.x + threadIdx.x;
+       index < total_count; index += blockDim.x * group_size) {
+    int32 image_offset = index % image_size;
+    int32 batch = index / image_size;
+    T val = ldg(output_backprop +
+                (batch * bias_size + bias_index) * image_size + image_offset);
+    sum += val;
   }
+
+  // Write the accumulated sum in this thread to the shared memory. Each thread
+  // shifts their write location to avoid bank conflict.
+  int bias_offset = threadIdx.x % 32;
+  CudaAtomicAdd(s_data + bias_offset, sum);
   __syncthreads();
 
-  for (int32 index = threadIdx.x; index < s_data_size; index += blockDim.x) {
-    int bias_slot_index = index / shared_replicas;
-    CudaAtomicAdd(bias_backprop + bias_slot_index, s_data[index]);
+  // Accumulate the results in the shared memory into the first element.
+  // No syncthreads is needed since this is only in the same warp.
+  int32 thread_index = threadIdx.x;
+  if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
+  if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
+  if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
+  if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
+  if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];
+
+  // The first thread writes out the accumulated result to the global location.
+  if (thread_index == 0) {
+    CudaAtomicAdd(bias_backprop + bias_index, s_data[0]);
   }
 }
 
@@ -154,24 +174,13 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
   const int32 bias_size = channel;
   const int32 image_size = height * width;
   const int32 total_count = batch * bias_size * image_size;
+  static constexpr int32 kWarpSize = 32;
   CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
 
   const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
-  int32 shared_memory_size = bias_size * sizeof(T);
-  int shared_replicas = 1;
-  if (data_format == FORMAT_NCHW) {
-    // For NCHW, the reduction in the HW dimensions all go to the same locaiton,
-    // which causes a lot of bank conflicts. So having a number of them can
-    // improve the performance. But we also want to limit their usage so the
-    // warp occupancy does not decrease.
-    if (shared_memory_size <= max_shared_memory_size) {
-      // We need enough shared memory to avoid bank conflict. But not too much
-      // so that it would reduce occupancy.
-      static constexpr int kMaxSharedReplicas = 8;
-      shared_replicas = std::min(kMaxSharedReplicas,
-                                 max_shared_memory_size / shared_memory_size);
-      shared_memory_size *= shared_replicas;
-    }
+  int32 shared_memory_size = 0;
+  if (data_format == FORMAT_NHWC) {
+    shared_memory_size = bias_size * sizeof(T);
   }
   // Check if we have enough shared memory.
   if (shared_memory_size <= max_shared_memory_size) {
@@ -181,10 +190,16 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
               d.stream()>>>(total_count, output_backprop, bias_backprop,
                             bias_size);
     } else {
+      // Round up the block count to multiple of bias_size.
+      int group_size = (config.block_count + bias_size - 1) / bias_size;
+      config.block_count = group_size * bias_size;
+      if (config.thread_per_block < kWarpSize) {
+        config.thread_per_block = kWarpSize;
+      }
       BiasGradNCHW_SharedAtomics<
-          T><<<config.block_count, config.thread_per_block, shared_memory_size,
-               d.stream()>>>(total_count, output_backprop, bias_backprop,
-                             bias_size, image_size, shared_replicas);
+          T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+          output_backprop, bias_backprop, batch, bias_size, image_size,
+          group_size);
     }
   } else {
     // Note that even if we don't have enough shared memory to fit the entire
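For orientation, the reduction both the old and the new NCHW kernels compute is simply a per-channel sum of output_backprop over the batch and spatial dimensions. A minimal CPU-side sketch of that sum is given below; the function name and use of std::vector<float> are illustrative only and are not part of the change.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical reference for BiasGrad in NCHW layout:
// bias_backprop[c] = sum over n, h, w of output_backprop[n][c][h][w],
// with image_size = height * width and bias_size = number of channels.
std::vector<float> BiasGradNCHWReference(
    const std::vector<float>& output_backprop, int32_t batch,
    int32_t bias_size, int32_t image_size) {
  std::vector<float> bias_backprop(bias_size, 0.0f);
  for (int32_t n = 0; n < batch; ++n) {
    for (int32_t c = 0; c < bias_size; ++c) {
      for (int32_t i = 0; i < image_size; ++i) {
        // Same linearized NCHW index the kernel uses:
        // (n * bias_size + c) * image_size + i.
        bias_backprop[c] +=
            output_backprop[(n * bias_size + c) * image_size + i];
      }
    }
  }
  return bias_backprop;
}
```

Relative to the old kernel, the rewrite in the diff replaces the replicated per-channel shared-memory slots with a fixed 32-element (one warp) buffer: each block is assigned a single channel (bias_index = blockIdx.x % bias_size), each thread accumulates its strided slice of that channel in a register, the per-thread sums are combined through the small shared array and an in-warp tree reduction, and only one atomic add per block touches global memory.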