diff options
author | 2016-11-01 11:09:06 -0800 | |
---|---|---|
committer | 2016-11-01 12:17:53 -0700 | |
commit | 6172351b81af76d0b819fea6bb478cbd4016d6c2 (patch) | |
tree | 81881e0c77ad2af471d94de7c235c8a94bf811ae | |
parent | 04d0920e464220a0c94f2c959e5db2dc8631b90d (diff) |
Speed up SwapDimension1And2InTensor3UsingTiles.
For all but the edge tiles (i.e. tiles whose height and width both equal
TileSize), we can unroll the loops that write to and read from our shared
memory tile.
Change: 137853968
-rw-r--r-- | tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index d9882764fb..dca0073a9a 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -221,9 +221,16 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize; } + int input_flat_index = input_origin_flat_index + x; + // Load the data from input memory to the shared memory tile. - if (x < tile_width) { - int input_flat_index = input_origin_flat_index + x; + if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) { +#pragma unroll + for (int y = 0; y < TileSize; y++) { + shared_memory_tile[y][x] = input[input_flat_index]; + input_flat_index += input_dims[2]; + } + } else if (x < tile_width) { for (int y = 0; y < tile_height; y++) { shared_memory_tile[y][x] = input[input_flat_index]; input_flat_index += input_dims[2]; @@ -247,7 +254,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, int output_flat_index = output_origin_flat_index + x; // Load the data from the shared memory tile to the output memory. - if (x < tile_height) { + if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) { +#pragma unroll + for (int y = 0; y < TileSize; y++) { + output[output_flat_index] = shared_memory_tile[x][y]; + output_flat_index += output_dims[2]; + } + } else if (x < tile_height) { for (int y = 0; y < tile_width; y++) { output[output_flat_index] = shared_memory_tile[x][y]; output_flat_index += output_dims[2]; |