author     ekelsen <ekelsen@gmail.com>   2018-04-04 13:53:36 -0700
committer  GitHub <noreply@github.com>   2018-04-04 13:53:36 -0700
commit     c1c819b28476d72c1f086fc4e78ff7f013c225ce (patch)
tree       5573cd4f6fbf66d7f70705400c03f37222794778
parent     5d33c1e49178aedbb459da7ce58eca710102c06b (diff)
parent     15f3b920ad7eb7fcca3afee14d16049db2046d4b (diff)
Merge pull request #17027 from nluehr/shared_complex_fix
Fix __shared__ complex<T> undefined behavior
-rw-r--r--  tensorflow/core/kernels/reduction_gpu_kernels.cu.h  |  37
1 file changed, 33 insertions(+), 4 deletions(-)
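
For context, the bug this fixes: CUDA does not support dynamic initialization of __shared__ variables, and std::complex<T>'s default constructor is not trivial, so a declaration such as

  __shared__ std::complex<float> partial_sums[32 * 33];

default-constructs every element and is the "__shared__ complex<T> undefined behavior" named in the title. The patch introduces a storage_type<T> wrapper whose specialization for std::complex<T> stores the real and imaginary parts as plain scalars and has an empty __host__ __device__ default constructor, so the shared arrays need no initialization; values are converted to and from std::complex<T> only when they are written or read.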
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 9237fa51d8..0de2ebb590 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -244,6 +244,33 @@ __global__ void RowReduceKernel(
if (row < num_rows && lane == 0) out[row] = sum;
}
+template <typename T1>
+struct storage_type {
+ T1 val;
+ __host__ __device__ storage_type() {}
+ __host__ __device__ operator T1() { return val; }
+ __host__ __device__ storage_type<T1>& operator=(const T1& in) {
+ val = in;
+ return *this;
+ }
+};
+
+template <typename T2>
+struct storage_type<std::complex<T2>> {
+ T2 real;
+ T2 imag;
+ __host__ __device__ storage_type() {}
+ __host__ __device__ operator std::complex<T2>() {
+ return std::complex<T2>(real, imag);
+ }
+ __host__ __device__ storage_type<std::complex<T2>>& operator=(
+ const std::complex<T2>& in) {
+ real = in.real();
+ imag = in.imag();
+ return *this;
+ }
+};
+
// Works only if there are <= 16 columns
// each warps sums over multiple rows at once
template <typename T, typename outT, typename Op>
@@ -268,7 +295,7 @@ __global__ void ColumnReduceMax16ColumnsKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ value_type partial_sums[32 * 33];
+ __shared__ storage_type<value_type> partial_sums[32 * 33];
row += rows_per_warp * gridDim.y * blockDim.y;
for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
@@ -294,7 +321,8 @@ __global__ void ColumnReduceMax16ColumnsKernel(
if (blockDim.y > 1) {
for (int row = 1; row < blockDim.y; ++row) {
- s = op(s, partial_sums[threadIdx.x * 33 + row]);
+ value_type t = partial_sums[threadIdx.x * 33 + row];
+ s = op(s, t);
}
}
@@ -316,7 +344,7 @@ __global__ void ColumnReduceKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ value_type partial_sums[32 * 33];
+ __shared__ storage_type<value_type> partial_sums[32 * 33];
row += gridDim.y * blockDim.y;
@@ -347,7 +375,8 @@ __global__ void ColumnReduceKernel(
min(blockDim.y, num_rows - blockIdx.y * blockDim.y);
for (int row = 1; row < numRowsThisBlock; ++row) {
- s = op(s, partial_sums[threadIdx.x * 33 + row]);
+ value_type t = partial_sums[threadIdx.x * 33 + row];
+ s = op(s, t);
}
out[col * gridDim.y + blockIdx.y] = s;
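
A minimal usage sketch of the new wrapper (hypothetical kernel, not part of the patch), assuming the storage_type template added above is in scope and the file is compiled the same way as these reduction kernels, so std::complex<T> can be constructed and copied in device code:

#include <complex>

// Reverses one block's worth of complex values, staging them through shared
// memory. Launch with blockDim.x <= 256.
template <typename T>
__global__ void BlockReverseSketch(const std::complex<T>* in,
                                   std::complex<T>* out) {
  // OK: storage_type elements need no initialization. A __shared__ array of
  // std::complex<T> here would be the undefined behavior this patch removes.
  __shared__ storage_type<std::complex<T>> scratch[256];

  scratch[threadIdx.x] = in[threadIdx.x];  // operator=(const std::complex<T>&)
  __syncthreads();

  // Read back through the implicit conversion to std::complex<T>.
  std::complex<T> v = scratch[blockDim.x - 1 - threadIdx.x];
  out[threadIdx.x] = v;
}

The wrapper's default constructor is deliberately empty, so declaring the __shared__ array performs no per-element work; conversion to and from std::complex<T> happens only in registers, which is also why the kernels above copy partial_sums[...] into a local value_type before passing it to op.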