Diffstat (limited to 'tensorflow/core/kernels/bias_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/bias_op_gpu.cu.cc | 23 ++++++++++++++---------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index ddc2d457b0..42f3db1d79 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,15 +173,20 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
   // Accumulate the results in the shared memory into the first element.
   // No syncthreads is needed since this is only in the same warp.
   int32 thread_index = threadIdx.x;
-  if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
-  if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
-  if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
-  if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
-  if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];
-
-  // The first thread writes out the accumulated result to the global location.
-  if (thread_index == 0) {
-    CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
+  if (thread_index < 16) {
+    s_data[thread_index] += s_data[thread_index + 16];
+    __syncwarp(0xFFFF);
+    if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
+    __syncwarp(0xFF);
+    if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
+    __syncwarp(0xF);
+    if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
+    __syncwarp(0x3);
+    if (thread_index == 0) {
+      T val = T(s_data[0] + s_data[1]);
+      // The first thread writes out the accumulated result to global location.
+      CudaAtomicAdd(bias_backprop + bias_index, val);
+    }
   }
 }
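
The change above replaces a reduction that relied on implicit warp-synchronous execution with one that synchronizes explicitly via __syncwarp(): on Volta and later GPUs, independent thread scheduling means the 32 lanes of a warp are no longer guaranteed to run in lockstep, so each step of the shared-memory tree reduction needs a barrier before the next step reads a neighbor's partial sum. Below is a minimal, self-contained sketch of that pattern, not the TensorFlow kernel itself; the kernel name warp_reduce_demo and its buffers are illustrative, and it uses the default full-warp __syncwarp() mask for simplicity, whereas the patch narrows the mask to the lanes still participating at each step.

// Minimal standalone sketch of the reduction pattern in the patch above.
// This is not the TensorFlow kernel; names (warp_reduce_demo, in, out) are
// illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_reduce_demo(const float* in, float* out) {
  // One warp (32 threads) per block; each block reduces 32 consecutive floats.
  __shared__ float s_data[32];
  int tid = threadIdx.x;
  s_data[tid] = in[blockIdx.x * 32 + tid];
  __syncwarp();  // make every lane's store visible before the tree reduction

  // Shared-memory tree reduction. Since Volta (compute capability 7.0),
  // warp lanes may not execute in lockstep, so each step is separated by an
  // explicit warp-level barrier. For simplicity, all 32 lanes reach every
  // __syncwarp() with the default full mask here; the patch instead narrows
  // the mask to the lanes still participating.
  if (tid < 16) s_data[tid] += s_data[tid + 16];
  __syncwarp();
  if (tid < 8) s_data[tid] += s_data[tid + 8];
  __syncwarp();
  if (tid < 4) s_data[tid] += s_data[tid + 4];
  __syncwarp();
  if (tid < 2) s_data[tid] += s_data[tid + 2];
  __syncwarp();
  if (tid == 0) {
    // Lane 0 folds in the last partial sum and accumulates the block result.
    atomicAdd(out, s_data[0] + s_data[1]);
  }
}

int main() {
  const int kBlocks = 4;
  float h_in[kBlocks * 32];
  for (int i = 0; i < kBlocks * 32; ++i) h_in[i] = 1.0f;  // expected total: 128

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, sizeof(float));

  warp_reduce_demo<<<kBlocks, 32>>>(d_in, d_out);

  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Compiled with nvcc (CUDA 9.0 or newer, which introduced __syncwarp), the sketch is expected to print sum = 128.000000 for the four blocks of 32 ones.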