aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-01 11:09:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-01 12:17:53 -0700
commit6172351b81af76d0b819fea6bb478cbd4016d6c2 (patch)
tree81881e0c77ad2af471d94de7c235c8a94bf811ae
parent04d0920e464220a0c94f2c959e5db2dc8631b90d (diff)
Speed up SwapDimension1And2InTensor3UsingTiles.
For all but the edge tiles, we can unroll the loops that read and write from our shared memory tile. Change: 137853968
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc19
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index d9882764fb..dca0073a9a 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -221,9 +221,16 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize;
}
+ int input_flat_index = input_origin_flat_index + x;
+
// Load the data from input memory to the shared memory tile.
- if (x < tile_width) {
- int input_flat_index = input_origin_flat_index + x;
+ if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
+#pragma unroll
+ for (int y = 0; y < TileSize; y++) {
+ shared_memory_tile[y][x] = input[input_flat_index];
+ input_flat_index += input_dims[2];
+ }
+ } else if (x < tile_width) {
for (int y = 0; y < tile_height; y++) {
shared_memory_tile[y][x] = input[input_flat_index];
input_flat_index += input_dims[2];
@@ -247,7 +254,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
int output_flat_index = output_origin_flat_index + x;
// Load the data from the shared memory tile to the output memory.
- if (x < tile_height) {
+ if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
+#pragma unroll
+ for (int y = 0; y < TileSize; y++) {
+ output[output_flat_index] = shared_memory_tile[x][y];
+ output_flat_index += output_dims[2];
+ }
+ } else if (x < tile_height) {
for (int y = 0; y < tile_width; y++) {
output[output_flat_index] = shared_memory_tile[x][y];
output_flat_index += output_dims[2];