path: root/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-04-07 14:17:57 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-07 15:33:32 -0700
commit    ef2f8891ad409e41f4f9b8e9cfd86b519adb6da6 (patch)
tree      b11a34ae900f98fd59a75e883efaa6d6706d82e5 /tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
parent    f2d09548722a98278fa9efd1ab33c0c87f9d58bd (diff)
Fix race due to unsafe buffer forwarding in maxpooling second order gradients added in #6664.
Re-enable previously flaky tests. Clean up a few minor things in maxpooling_op_gpu.cu.cc.
Change: 152550050
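The cleanup keeps the zero-then-launch pattern visible in the MaxPoolBackwardNoMask hunk further down: the gradient buffer is zero-filled with SetZero and the backward kernel is then enqueued on the same CUDA stream. A minimal sketch of that pattern follows; the SetZeroSketch/ZeroThenLaunchSketch names are hypothetical stand-ins, not the functions in the file.

#include <cuda_runtime.h>

// Hypothetical stand-in for the SetZero kernel referenced in the diff below:
// a grid-stride loop that writes zero to every element of the buffer.
template <typename T>
__global__ void SetZeroSketch(const int nthreads, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = T(0);
  }
}

// Sketch of the launch ordering used before the backward kernel runs.
template <typename T>
void ZeroThenLaunchSketch(T* bottom_diff, int bottom_size,
                          cudaStream_t stream) {
  const int kThreadsPerBlock = 1024;
  // Enqueue the zero-fill; the work is ordered on `stream`, not yet executed.
  SetZeroSketch<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
                  kThreadsPerBlock, 0, stream>>>(bottom_size, bottom_diff);
  // The real code then enqueues MaxPoolBackwardNoMaskNHWC on the same stream.
  // CUDA stream ordering guarantees the zero-fill completes before that
  // kernel reads or accumulates into bottom_diff, so no explicit sync needed.
}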
Diffstat (limited to 'tensorflow/core/kernels/maxpooling_op_gpu.cu.cc')
-rw-r--r--  tensorflow/core/kernels/maxpooling_op_gpu.cu.cc  9
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 32b210ecb7..e3a57d2f28 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -70,7 +70,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * channels * height * width;
for (int h = hstart; h < hend; ++h) {
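On the -FLT_MAX to Eigen::NumTraits<dtype>::lowest() change above: the kernel is templated on dtype, and for types narrower than float such as Eigen::half, -FLT_MAX is not representable (it converts to -inf), whereas NumTraits<dtype>::lowest() gives a well-defined per-type minimum for the running max. A small host-side illustration, assuming a standalone Eigen build that provides Eigen::half via Eigen/Core; this is not part of the diff.

#include <cfloat>
#include <iostream>
#include <Eigen/Core>  // Eigen::half, Eigen::NumTraits

int main() {
  // For float, lowest() and -FLT_MAX are the same value.
  std::cout << Eigen::NumTraits<float>::lowest() << " vs " << -FLT_MAX << "\n";
  // For half, -FLT_MAX overflows to -inf on conversion, while lowest() is the
  // smallest finite half (-65504), a well-defined sentinel for the max scan.
  std::cout << static_cast<float>(Eigen::NumTraits<Eigen::half>::lowest())
            << " vs " << static_cast<float>(Eigen::half(-FLT_MAX)) << "\n";
  return 0;
}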
@@ -312,9 +312,6 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool output.
// This is equal to Hout*Wout*C.
// bottom_diff: the gradient of the gradient w.r.t. output.
-// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
-// the kernel is run, you will need to make sure that bottom_diff is filled with
-// zero first.
template <typename dtype>
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
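The deleted comment described requirements (CudaAtomicAdd, a pre-zeroed bottom_diff) that do not apply to this kernel: MaxPoolGradBackward gathers through the argmax mask, so each output element is written by exactly one thread. An illustrative kernel in that spirit is below; parameter meanings follow the comment above (top_offset is the per-image input offset, bottom_offset the per-image output offset), but it is a sketch, not the exact body, and uses long long where the TF source uses int64.

// Sketch of a mask-based gather: one write per bottom_diff element, hence no
// atomics and no prior zero-fill are required.
template <typename dtype>
__global__ void MaxPoolGradBackwardSketch(const int nthreads,
                                          const dtype* top_diff,
                                          const long long* mask,
                                          const int top_offset,
                                          const int bottom_offset,
                                          dtype* bottom_diff) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    const int image_id = index / bottom_offset;  // which image in the batch
    // Pull the incoming gradient from the argmax position recorded in mask
    // and write it straight through to this output element.
    bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
  }
}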
@@ -357,12 +354,12 @@ bool MaxPoolBackwardNoMask<T>::operator()(
const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
- const int bottom_size = batch * channels * height * width;
- const int top_size = batch * channels * pooled_height * pooled_width;
+ const int bottom_size = batch * channels * height * width;
SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
+ const int top_size = batch * channels * pooled_height * pooled_width;
MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(