Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu_3.cu.cc')
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 127
1 file changed, 123 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index 3d4670c9ba..9083626fbf 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -272,6 +272,88 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
}
+// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor
+// when only one of the dimension sizes is smaller than 16,
+// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
+//
+// small_dim = the smaller of the two swapped dimension sizes
+// large_dim = the larger of the two swapped dimension sizes
+// kTileLength = blockDim.x (we currently set it to 64)
+// tile_num_per_block = gridDim.x = (large_dim + kTileLength - 1) / kTileLength
+//
+// Each thread block operates on rectangular tiles whose width is kTileLength
+// (along large_dim) and whose height is small_dim, processing one such tile
+// for each of the batch_per_block batches assigned to it.
+// We set the thread block's X dimension to kTileLength, and its Y and Z
+// dimensions to one.
+template <typename T, int ShmemSize, bool SmallDim2>
+__global__ void SwapDimension1And2InTensor3SmallDim(const T* input,
+ int batch_per_block,
+ Dimension<3> input_dims,
+ T* output) {
+ // TODO(yangzihao): Avoid shared memory bank conflicts.
+ __shared__ T shared_memory_tile[ShmemSize];
+
+ eigen_assert(blockDim.y == 1);
+ eigen_assert(blockDim.z == 1);
+ eigen_assert(gridDim.z == 1);
+
+ int block_offset = blockIdx.x * blockDim.x;
+
+ int x = threadIdx.x;
+ int tile_height = blockDim.x;
+
+ // Get the swapped dimension sizes and this block's global starting offset.
+ int small_dim = SmallDim2 ? input_dims[2] : input_dims[1];
+ int large_dim = SmallDim2 ? input_dims[1] : input_dims[2];
+
+ int global_offset = small_dim * large_dim * (blockIdx.y * batch_per_block) +
+ (SmallDim2 ? block_offset * small_dim : block_offset);
+ if (global_offset >= (input_dims[0] * input_dims[1] * input_dims[2])) return;
+
+ for (int batch = 0; batch < batch_per_block; ++batch) {
+ int block_origin_idx =
+ small_dim * large_dim * (blockIdx.y * batch_per_block + batch);
+ int thread_origin_idx =
+ block_origin_idx +
+ (SmallDim2 ? block_offset * small_dim : block_offset) + x;
+
+ if (block_offset + blockDim.x > large_dim) {
+ tile_height = large_dim - block_offset;
+ }
+
+ __syncthreads();
+
+ // Load a contiguous memory region into the shared memory tile.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int shmem_index =
+ SmallDim2 ? (x + y * tile_height) : (x * small_dim + y);
+ shared_memory_tile[shmem_index] =
+ ldg(input + thread_origin_idx +
+ y * (SmallDim2 ? tile_height : large_dim));
+ }
+ }
+
+ __syncthreads();
+
+ // Get block origin index for output array.
+ int output_block_offset = block_origin_idx;
+ int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim;
+ int output_block_origin_idx = output_block_offset + output_block_idx;
+
+ // Store the transposed tile from shared memory back to device memory.
+ if (x < tile_height) {
+ for (int y = 0; y < small_dim; y++) {
+ int output_idx = output_block_origin_idx + x +
+ y * (SmallDim2 ? large_dim : tile_height);
+ int shmem_index =
+ SmallDim2 ? (x * small_dim + y) : (x + y * tile_height);
+ output[output_idx] = shared_memory_tile[shmem_index];
+ }
+ }
+ }
+}
+
// A CUDA custom kernel that converts input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T, int NDIMS>
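For context, the transpose implemented by SwapDimension1And2InTensor3SmallDim is output[i][j][k] = input[i][k][j] over a row-major tensor. Below is a minimal host-side sketch of that index mapping; it is not part of this change, and the helper name ReferenceSwapDim1And2 is hypothetical.

#include <cstddef>

// Hypothetical reference (not in this commit): swap dimensions 1 and 2 of a
// dims[0] x dims[1] x dims[2] row-major tensor, so output[i][j][k] == input[i][k][j].
template <typename T>
void ReferenceSwapDim1And2(const T* input, const int dims[3], T* output) {
  for (int i = 0; i < dims[0]; ++i) {
    for (int k = 0; k < dims[1]; ++k) {    // index along original dimension 1
      for (int j = 0; j < dims[2]; ++j) {  // index along original dimension 2
        std::size_t in_idx = (static_cast<std::size_t>(i) * dims[1] + k) * dims[2] + j;
        std::size_t out_idx = (static_cast<std::size_t>(i) * dims[2] + j) * dims[1] + k;
        output[out_idx] = input[in_idx];
      }
    }
  }
}

The kernel above reproduces this mapping while staging a kTileLength x small_dim tile in shared memory, so the global-memory loads and stores can be coalesced across the threads of a block.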
@@ -420,25 +502,62 @@ template <typename T>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
const Dimension<3>& input_dims, T* output) {
// If both dimensions are not trivial, use tiles for the actual swapping.
+ // If only one dimension is trivial, use the SmallDim kernel for swapping.
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
static const int kMinDimensionToUseTiles = 16;
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
+ bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles &&
+ input_dims[2] < kMinDimensionToUseTiles)) ||
+ ((input_dims[1] < kMinDimensionToUseTiles &&
+ input_dims[2] >= kMinDimensionToUseTiles));
+ static const int NumSubTiles = 8;
+
if (use_tiles) {
- // We get best performance when TileSize is the number of threads in a warp
- // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
- // threads.
static const int TileSize = 32;
- static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
+ // We get best performance when TileSize is the number of threads in a warp
+ // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
+ // threads.
SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
+ } else if (use_small_dim) {
+ // When only one of the dimensions is smaller than kMinDimensionToUseTiles,
+ // we use one block to process a rectangular region of size
+ // kTileLength * small_dim. We found that setting kTileLength to 64
+ // achieves the best performance on a Titan X (Maxwell) GPU.
+ // large_dim
+ // +---------------...--------+
+ // | | | |
+ // small_dim | | ... | |
+ // | | | |
+ // +--------------...---------+
+ // \----- ------/ \- -/
+ // V V
+ // kTileLength(tile_height) tile_height
+ static const int kTileLength = 64;
+ static const int kGridDimY = 65535;
+ int large_dim = std::max(input_dims[2], input_dims[1]);
+ int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;
+ int grid_dim_y = std::min(input_dims[0], kGridDimY);
+ int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;
+ if (input_dims[2] < input_dims[1]) {
+ SwapDimension1And2InTensor3SmallDim<
+ T, kTileLength * kMinDimensionToUseTiles, true>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
+ d.stream()>>>(input, batch_per_block, input_dims, output);
+ } else {
+ SwapDimension1And2InTensor3SmallDim<
+ T, kTileLength * kMinDimensionToUseTiles, false>
+ <<<dim3(tile_num_per_block, grid_dim_y), kTileLength, 0,
+ d.stream()>>>(input, batch_per_block, input_dims, output);
+ }
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
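To make the SmallDim launch arithmetic concrete, here is a standalone sketch that recomputes the grid configuration for an assumed input shape; the shape {200000, 30000, 3} and the program itself are illustrative and not taken from this change.

#include <algorithm>
#include <cstdio>

int main() {
  const int kTileLength = 64;   // blockDim.x of the SmallDim kernel
  const int kGridDimY = 65535;  // CUDA limit on gridDim.y
  int input_dims[3] = {200000, 30000, 3};  // dim 2 is trivial, so the SmallDim path is taken

  int large_dim = std::max(input_dims[1], input_dims[2]);                // 30000
  int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength;  // 469 tiles cover large_dim
  int grid_dim_y = std::min(input_dims[0], kGridDimY);                   // 65535
  int batch_per_block = (input_dims[0] + grid_dim_y - 1) / grid_dim_y;   // 4 batches per block

  // The kernel would be launched with grid (469, 65535, 1) and block (64, 1, 1);
  // each block transposes one kTileLength x small_dim tile for each of its
  // batch_per_block batches.
  std::printf("grid=(%d,%d,1) block=(%d,1,1) batch_per_block=%d\n",
              tile_num_per_block, grid_dim_y, kTileLength, batch_per_block);
  return 0;
}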