path: root/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-15 14:51:16 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-11-15 15:05:56 -0800
commit    d2693c8a70567cc78b2e8a9ac8020d321620ca83 (patch)
tree      6fe62b4b17967aa8a774801eb4de1de756ed4fbc /tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
parent    cace91e815242074852606367e6e2a09f09eedd2 (diff)
Increase the block size of SwapDimension1And2InTensor3UsingTiles from 32 to 128 threads.
This increases the parallelism in this kernel, yielding a nice speedup.
Change: 139250945
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu_3.cu.cc')
-rw-r--r-- tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 68
1 file changed, 37 insertions(+), 31 deletions(-)
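To make the sizing concrete, here is a small host-side sketch of the arithmetic this patch sets up. The input shape is illustrative and not taken from this commit; TileSize and NumSubTiles follow the values chosen in the diff below, where the new in-code comment works out the block size as 8 * 32 = 256 threads.

// Hypothetical sizing example: input shape [batch, dim1, dim2] = [64, 96, 80].
constexpr int TileSize = 32;    // each tile is still 32 x 32 elements
constexpr int NumSubTiles = 8;  // new: rows of threads per block
const int input_dims[3] = {64, 96, 80};
// Grid: one block per 32 x 32 tile of the dim1 x dim2 plane, per batch entry.
const int tiles_dim1 = (input_dims[1] + TileSize - 1) / TileSize;       // 3
const int tiles_dim2 = (input_dims[2] + TileSize - 1) / TileSize;       // 3
const int total_tiles_count = input_dims[0] * tiles_dim1 * tiles_dim2;  // 576
// Block: TileSize x NumSubTiles = 32 x 8 = 256 threads, so each thread copies
// TileSize / NumSubTiles = 4 rows of its tile instead of all 32.
const dim3 block_dim(TileSize, NumSubTiles);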
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index dca0073a9a..d7f6923f17 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <algorithm>
#include <array>
+#include "cuda/include/cuda.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -176,20 +177,34 @@ __global__ void SwapDimension1And2InTensor3(int nthreads, const T* input,
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
-// TileSize could be arbitrary. But for best performance, it is better to be
-// the same as number of threads in a warp, which is 32 for almost all GPU
-// architectures.
-template <typename T, int TileSize>
+//
+// Each thread block operates on a single tile, a square of dimensions TileSize
+// x TileSize. We require that the thread block's X dimension equals TileSize,
+// and its Y dimension equals NumSubTiles.
+//
+// For best performance, you should probably set TileSize equal to the number of
+// threads in a warp (32 in nvidia GPUs). With a TileSize of 32, NumSubTiles ==
+// 4 or 8 seems to get the best performance on K40 GPUs.
+template <typename T, int TileSize, int NumSubTiles>
__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
Dimension<3> input_dims,
T* output) {
// One extra line in the inner dimension to avoid share memory bank conflict.
__shared__ T shared_memory_tile[TileSize][TileSize + 1];
+ static_assert(TileSize % NumSubTiles == 0,
+ "TileSize must be divisible by NumSubTiles");
+ eigen_assert(blockDim.x == TileSize);
+ eigen_assert(blockDim.y == NumSubTiles);
+ eigen_assert(blockDim.z == 1);
+ eigen_assert(gridDim.y == 1);
+ eigen_assert(gridDim.z == 1);
+
+ // We break down the tile into NumSubTiles groups, so each thread processes
+ // kSubTileSize elements (except at the edges of the input).
+ const int kSubTileSize = TileSize / NumSubTiles;
+
int x = threadIdx.x;
- if (x >= TileSize) {
- return;
- }
Dimension<3> output_dims = {
input_dims[0], input_dims[2], input_dims[1],
@@ -222,18 +237,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
}
int input_flat_index = input_origin_flat_index + x;
+ int y_start = static_cast<int>(threadIdx.y) * kSubTileSize;
// Load the data from input memory to the shared memory tile.
- if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
-#pragma unroll
- for (int y = 0; y < TileSize; y++) {
- shared_memory_tile[y][x] = input[input_flat_index];
- input_flat_index += input_dims[2];
- }
- } else if (x < tile_width) {
- for (int y = 0; y < tile_height; y++) {
- shared_memory_tile[y][x] = input[input_flat_index];
- input_flat_index += input_dims[2];
+ if (x < tile_width) {
+ int y_end = min(y_start + kSubTileSize, tile_height);
+ for (int y = y_start; y < y_end; y++) {
+ shared_memory_tile[y][x] = input[input_flat_index + y * input_dims[2]];
}
}
@@ -254,16 +264,10 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
int output_flat_index = output_origin_flat_index + x;
// Load the data from the shared memory tile to the output memory.
- if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
-#pragma unroll
- for (int y = 0; y < TileSize; y++) {
- output[output_flat_index] = shared_memory_tile[x][y];
- output_flat_index += output_dims[2];
- }
- } else if (x < tile_height) {
- for (int y = 0; y < tile_width; y++) {
- output[output_flat_index] = shared_memory_tile[x][y];
- output_flat_index += output_dims[2];
+ if (x < tile_height) {
+ int y_end = min(y_start + kSubTileSize, tile_width);
+ for (int y = y_start; y < y_end; y++) {
+ output[output_flat_index + y * output_dims[2]] = shared_memory_tile[x][y];
}
}
}
@@ -421,17 +425,19 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
input_dims[2] >= kMinDimensionToUseTiles);
if (use_tiles) {
- // The tile-size can be chosen to be arbitrary number. But it is better to
- // be the same as number of threads in a warp, which is 32.
+ // We get best performance when TileSize is the number of threads in a warp
+ // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
+ // threads.
static const int TileSize = 32;
+ static const int NumSubTiles = 8;
Dimension<3> input_dims_in_tiles = {
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
(input_dims[2] + TileSize - 1) / TileSize,
};
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
input_dims_in_tiles[2];
- SwapDimension1And2InTensor3UsingTiles<
- T, TileSize><<<total_tiles_count, TileSize, 0, d.stream()>>>(
+ SwapDimension1And2InTensor3UsingTiles<T, TileSize, NumSubTiles><<<
+ total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>(
input, input_dims, output);
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
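For experimenting with the technique outside of TensorFlow, the following is a minimal, self-contained CUDA sketch of the same sub-tiled shared-memory transpose. It is not the TensorFlow kernel: the batch dimension is dropped, the edge handling is simplified, and the kernel name, matrix sizes, and main() harness are invented for illustration, but the load / __syncthreads() / transposed-store structure follows the pattern in the diff above.

// Minimal sub-tiled shared-memory transpose sketch (NOT the TensorFlow kernel).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kTileSize = 32;                           // tile is 32 x 32 elements
constexpr int kNumSubTiles = 8;                         // blockDim.y
constexpr int kSubTileSize = kTileSize / kNumSubTiles;  // rows copied per thread (4)

__global__ void TransposeTiled(const float* in, float* out, int rows, int cols) {
  // One extra column avoids shared-memory bank conflicts on the transposed reads.
  __shared__ float tile[kTileSize][kTileSize + 1];

  const int x = threadIdx.x;                       // column within the tile
  const int y_start = threadIdx.y * kSubTileSize;  // first row this thread copies
  const int tile_row = blockIdx.y * kTileSize;     // tile origin in the input
  const int tile_col = blockIdx.x * kTileSize;

  // Load: each thread copies kSubTileSize elements of one column into the tile.
  for (int dy = 0; dy < kSubTileSize; ++dy) {
    const int r = tile_row + y_start + dy;
    const int c = tile_col + x;
    if (r < rows && c < cols) tile[y_start + dy][x] = in[r * cols + c];
  }
  __syncthreads();

  // Store: write the tile back with the row/column roles swapped.
  for (int dy = 0; dy < kSubTileSize; ++dy) {
    const int r = tile_col + y_start + dy;  // row in the transposed output
    const int c = tile_row + x;             // column in the transposed output
    if (r < cols && c < rows) out[r * rows + c] = tile[x][y_start + dy];
  }
}

int main() {
  const int rows = 96, cols = 80;  // illustrative sizes, not from the commit
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, rows * cols * sizeof(float));
  cudaMallocManaged(&out, rows * cols * sizeof(float));
  for (int i = 0; i < rows * cols; ++i) in[i] = static_cast<float>(i);

  const dim3 block(kTileSize, kNumSubTiles);  // 32 x 8 = 256 threads, as in the patch
  const dim3 grid((cols + kTileSize - 1) / kTileSize,
                  (rows + kTileSize - 1) / kTileSize);
  TransposeTiled<<<grid, block>>>(in, out, rows, cols);
  cudaDeviceSynchronize();

  // The output is cols x rows; out[j][i] should equal in[i][j].
  printf("out[1][0] = %g, in[0][1] = %g\n", out[1 * rows + 0], in[0 * cols + 1]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}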