aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc19
1 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index d9882764fb..dca0073a9a 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -221,9 +221,16 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize;
}
+ int input_flat_index = input_origin_flat_index + x;
+
// Load the data from input memory to the shared memory tile.
- if (x < tile_width) {
- int input_flat_index = input_origin_flat_index + x;
+ if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
+#pragma unroll
+ for (int y = 0; y < TileSize; y++) {
+ shared_memory_tile[y][x] = input[input_flat_index];
+ input_flat_index += input_dims[2];
+ }
+ } else if (x < tile_width) {
for (int y = 0; y < tile_height; y++) {
shared_memory_tile[y][x] = input[input_flat_index];
input_flat_index += input_dims[2];
@@ -247,7 +254,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
int output_flat_index = output_origin_flat_index + x;
// Load the data from the shared memory tile to the output memory.
- if (x < tile_height) {
+ if (TF_PREDICT_TRUE(tile_height == TileSize && tile_width == TileSize)) {
+#pragma unroll
+ for (int y = 0; y < TileSize; y++) {
+ output[output_flat_index] = shared_memory_tile[x][y];
+ output_flat_index += output_dims[2];
+ }
+ } else if (x < tile_height) {
for (int y = 0; y < tile_width; y++) {
output[output_flat_index] = shared_memory_tile[x][y];
output_flat_index += output_dims[2];