author    Xiaoqiang Zheng <zhengxq@google.com>    2018-06-01 00:30:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-01 00:32:18 -0700
commit    73e5438b725b46e745e6e910c6557b51a321c70f (patch)
tree      1752bf4a844619ccc1b1dcdc99a7185300311697 /tensorflow
parent    961a39346d8be33cff473f1e81498b887c155070 (diff)
Remove the constructor in shared memory.
PiperOrigin-RevId: 198837256
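
CUDA does not allow dynamic initialization of __shared__ variables, so a type with a non-trivial constructor cannot be declared directly in shared memory. Below is a minimal sketch of the workaround this change applies: declare a raw byte buffer with the required size and alignment, then reinterpret it as the desired element type. The type Accum and kernel ScratchKernel are illustrative names, not part of this commit.

#include <cuda_runtime.h>

struct Accum {
  float v;
  __device__ Accum() : v(0.f) {}  // non-trivial constructor
};

__global__ void ScratchKernel(float* out) {
  // Rejected by nvcc: __shared__ variables cannot be dynamically initialized.
  // __shared__ Accum scratch[128];

  // Accepted: raw bytes with matching size and alignment, reinterpreted.
  __shared__ __align__(alignof(Accum)) char scratch_raw[128 * sizeof(Accum)];
  Accum* scratch = reinterpret_cast<Accum*>(scratch_raw);

  // Elements are never constructed, so each slot must be fully written
  // before it is read. Launch with 128 threads, e.g.
  // ScratchKernel<<<1, 128>>>(out);
  scratch[threadIdx.x].v = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[0] = scratch[127].v;
}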
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu_3.cu.cc        |  8 +++++++-
-rw-r--r--  tensorflow/core/kernels/reduction_gpu_kernels.cu.h  | 12 ++++++++++--
2 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a2e7342b04..a5fa48f85e 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -247,7 +247,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
constexpr int WriteRowPerPass = NumThreads / TileSizeI;
// One extra line in the inner dimension to avoid shared memory bank conflicts.
- __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+ // This is to mimic the following, but without invoking T's constructor:
+ // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+ __shared__ __align__(
+ alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
+ typedef T(*SharedMemoryTile)[TileSizeJ + 1];
+ SharedMemoryTile shared_memory_tile =
+ reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
int x = threadIdx.x;
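
The typedef above is the standard C++ device for viewing a flat buffer as a 2D array: T (*)[TileSizeJ + 1] is a pointer to rows of TileSizeJ + 1 elements, so shared_memory_tile[i][j] computes the same byte offsets as the original two-dimensional declaration. A small host-side illustration of the same cast (the names here are ours, not the commit's):

#include <cassert>

int main() {
  constexpr int Rows = 4, Cols = 5;
  alignas(float) char raw[Rows * Cols * sizeof(float)];

  typedef float (*Tile)[Cols];  // pointer to rows of Cols floats
  Tile tile = reinterpret_cast<Tile>(raw);

  tile[2][3] = 42.f;  // row-major offset (2 * Cols + 3) * sizeof(float)
  assert(reinterpret_cast<float*>(raw)[2 * Cols + 3] == 42.f);
  return 0;
}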
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 0de2ebb590..6655084045 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -295,7 +295,11 @@ __global__ void ColumnReduceMax16ColumnsKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ storage_type<value_type> partial_sums[32 * 33];
+ // This is to mimic the following, but without invoking any constructors:
+ // __shared__ storage_type<value_type> partial_sums[32 * 33];
+ __shared__ __align__(
+ alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
+ value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
row += rows_per_warp * gridDim.y * blockDim.y;
for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
@@ -344,7 +348,11 @@ __global__ void ColumnReduceKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ storage_type<value_type> partial_sums[32 * 33];
+ // This is to mimic the following, but without invoking any constructors:
+ // __shared__ storage_type<value_type> partial_sums[32 * 33];
+ __shared__ __align__(
+ alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
+ value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
row += gridDim.y * blockDim.y;
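
In both reduction kernels, the 32 * 33 size mirrors a 32x32 tile with one padding column; since the buffer is now flat, row/column indexing is done by hand. A rough sketch of how such a padded, flattened tile is typically indexed (an illustrative kernel, not part of this change): with a 32x32 thread block, the stride of 33 maps the threads of a warp to distinct shared-memory banks on both the row-major write and the transposed read, because gcd(33, 32) == 1.

__global__ void PaddedTileDemo(const float* in, float* out) {
  // Flat buffer standing in for a [32][33] array: 32 rows plus one
  // padding column so strided accesses avoid bank conflicts.
  __shared__ __align__(alignof(float)) char raw[32 * 33 * sizeof(float)];
  float* tile = reinterpret_cast<float*>(raw);

  int r = threadIdx.y, c = threadIdx.x;  // launch with a 32x32 block
  tile[r * 33 + c] = in[r * 32 + c];     // row-major write
  __syncthreads();
  out[c * 32 + r] = tile[c * 33 + r];    // transposed, conflict-free read
}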