aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-25 08:48:47 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-25 08:48:47 -0800
commit854f49bd43588c062b046384f239f64a3d819702 (patch)
treec2373bf71ef65ae4c116ea703947141281c1eace /tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
parent9c3043ff3bf31a6a81810b4ce9e87ef936f1f529 (diff)
TensorFlow: Upstream changes to git
Changes: - Updates to docs - Several changes for Python 3 compatibility - Added license headers Base CL: 108710566
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu_3.cu.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc177
1 files changed, 163 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a40b3caefd..2d7e149c7b 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// A simple array that contains data that can be passed between CPU and GPU.
-template <typename T, int IndexCount>
+template <typename T, int IndexCount, T DefaultValue>
struct Array {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
return data[index];
@@ -38,16 +38,57 @@ struct Array {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
return data[index];
}
- int data[IndexCount];
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {  // Default constructor: every one of the IndexCount elements becomes DefaultValue.
+ for (int i = 0; i < IndexCount; i++) {
+ data[i] = DefaultValue;
+ }
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {  // Element 0 is a0; all remaining elements become DefaultValue.
+ data[0] = a0;  // Written unconditionally, so this overload needs IndexCount >= 1.
+ for (int i = 1; i < IndexCount; i++) {
+ data[i] = DefaultValue;
+ }
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {  // Elements 0 and 1 are a0 and a1; all remaining elements become DefaultValue.
+ data[0] = a0;
+ data[1] = a1;  // Written unconditionally, so this overload needs IndexCount >= 2.
+ for (int i = 2; i < IndexCount; i++) {
+ data[i] = DefaultValue;
+ }
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {  // Elements 0..2 are a0, a1, a2; all remaining elements become DefaultValue.
+ data[0] = a0;
+ data[1] = a1;
+ data[2] = a2;  // Written unconditionally, so this overload needs IndexCount >= 3.
+ for (int i = 3; i < IndexCount; i++) {
+ data[i] = DefaultValue;
+ }
+ }
+ T data[IndexCount];
};
// A dimension type with compile-time known size.
template <int IndexCount>
-struct Dimension : Array<int, IndexCount> {};
+struct Dimension : Array<int, IndexCount, 1> {  // DefaultValue is 1: extents not passed to a constructor are set to 1.
+ typedef Array<int, IndexCount, 1> Base;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}  // All extents are 1.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}  // Extent 0 is a0.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
+ : Base(a0, a1) {}  // Extents 0 and 1 are a0 and a1.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
+ : Base(a0, a1, a2) {}  // Extents 0..2 are a0, a1, a2.
+};
// An index type with compile-time known size.
template <int IndexCount>
-struct Index : Array<int, IndexCount> {};
+struct Index : Array<int, IndexCount, 0> {  // DefaultValue is 0: coordinates not passed to a constructor are set to 0.
+ typedef Array<int, IndexCount, 0> Base;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}  // All coordinates are 0.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}  // Coordinate 0 is a0.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}  // Coordinates 0 and 1 are a0 and a1.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
+ : Base(a0, a1, a2) {}  // Coordinates 0..2 are a0, a1, a2.
+};
// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
@@ -94,7 +135,7 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
- output[output_index] = __ldg(input + input_index);
+ output[output_index] = ldg(input + input_index);
}
}
@@ -119,7 +160,88 @@ __global__ void SwapDimension1And2InTensor3(int nthreads, const T* input,
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
- output[output_index] = __ldg(input + input_index);
+ output[output_index] = ldg(input + input_index);
+ }
+}
+
+// Swaps dimension-1 and dimension-2 of a 3D tensor through a shared memory tile
+// (dimensions are zero-based): output[i][j][k] = input[i][k][j]. Each block
+// handles the tile whose flat index is blockIdx.x; thread x copies column x of
+// the input tile in, then writes column x of the output tile from tile row x.
+// TileSize may be arbitrary, but the warp size (32 on almost all GPUs) is best.
+template <typename T, int TileSize>
+__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
+ Dimension<3> input_dims,
+ T* output) {
+ // One extra column in the inner dimension to avoid shared memory bank conflicts.
+ __shared__ T shared_memory_tile[TileSize][TileSize + 1];
+
+ int x = threadIdx.x;
+ if (x >= TileSize) {  // NOTE(review): these threads exit before __syncthreads(); the launcher in this file uses exactly TileSize threads per block.
+ return;
+ }
+
+ Dimension<3> output_dims = {  // Output extents: input extents with dimensions 1 and 2 exchanged.
+ input_dims[0], input_dims[2], input_dims[1],
+ };
+
+ Dimension<3> input_dims_in_tiles = {  // Tile counts per dimension; dimensions 1 and 2 are rounded up.
+ input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
+ (input_dims[2] + TileSize - 1) / TileSize,
+ };
+
+ Index<3> input_tile_index =
+ FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);  // blockIdx.x is treated as the flat index of this block's tile.
+
+ Index<3> input_tile_origin = {  // Element coordinates of the first element of this tile.
+ input_tile_index[0], input_tile_index[1] * TileSize,
+ input_tile_index[2] * TileSize,
+ };
+
+ int input_origin_flat_index =
+ TensorIndexToFlat(input_tile_origin, input_dims);
+
+ int tile_width = TileSize;
+ // Only the last row or column of tiles may be smaller than the full TileSize.
+ if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
+ tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSize;
+ }
+ int tile_height = TileSize;
+ if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
+ tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize;
+ }
+
+ // Load the data from input memory to the shared memory tile; thread x copies column x.
+ if (x < tile_width) {
+ int input_flat_index = input_origin_flat_index + x;
+ for (int y = 0; y < tile_height; y++) {
+ shared_memory_tile[y][x] = input[input_flat_index];
+ input_flat_index += input_dims[2];  // Step to the next index along dimension 1 of the input.
+ }
+ }
+
+ __syncthreads();  // The tile is read below by threads other than the ones that wrote it.
+
+ Index<3> output_tile_index = {  // Same tile, with its dimension-1 and dimension-2 tile coordinates exchanged.
+ input_tile_index[0], input_tile_index[2], input_tile_index[1],
+ };
+
+ Index<3> output_tile_origin = {
+ output_tile_index[0], output_tile_index[1] * TileSize,
+ output_tile_index[2] * TileSize,
+ };
+
+ int output_origin_flat_index =
+ TensorIndexToFlat(output_tile_origin, output_dims);
+
+ int output_flat_index = output_origin_flat_index + x;
+
+ // Store the data from the shared memory tile to the output memory; thread x reads tile row x.
+ if (x < tile_height) {
+ for (int y = 0; y < tile_width; y++) {
+ output[output_flat_index] = shared_memory_tile[x][y];
+ output_flat_index += output_dims[2];  // Step to the next index along dimension 1 of the output.
+ }
}
}
@@ -212,6 +334,39 @@ struct PadInput<GPUDevice, T, int> {
}
};
+// Launches, on d.stream(), a kernel that swaps dimension-1 and dimension-2 of
+// the 3D tensor `input` (extents `input_dims`) into `output`. The extents of
+// dimensions 1 and 2 decide between the tiled and the element-wise kernel.
+template <typename T>
+void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
+ const Dimension<3>& input_dims, T* output) {
+ // If neither dimension is trivial, use tiles for the actual swapping.
+ // Otherwise, the trivial swapping relying on the ldg cache is more efficient.
+ static const int kMinDimensionToUseTiles = 16;
+ bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
+ input_dims[2] >= kMinDimensionToUseTiles);
+ if (use_tiles) {
+ // The tile size can be chosen arbitrarily, but it is better for it to be
+ // the same as the number of threads in a warp, which is 32.
+ static const int TileSize = 32;
+ Dimension<3> input_dims_in_tiles = {  // Tile counts per dimension; dimensions 1 and 2 are rounded up.
+ input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
+ (input_dims[2] + TileSize - 1) / TileSize,
+ };
+ int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
+ input_dims_in_tiles[2];  // NOTE(review): plain int product used as the grid size; confirm it stays within int and device grid limits.
+ SwapDimension1And2InTensor3UsingTiles<
+ T, TileSize><<<total_tiles_count, TileSize, 0, d.stream()>>>(
+ input, input_dims, output);  // One block per tile, TileSize threads per block, no dynamic shared memory.
+ } else {
+ int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
+ SwapDimension1And2InTensor3<
+ T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, input, input_dims, output);  // Launch shape is taken from the config computed above.
+ }
+}
+
// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T>
@@ -223,10 +378,7 @@ struct NHWCToNCHW<GPUDevice, T> {
combined_dims[0] = in.dimension(0);
combined_dims[1] = in.dimension(1) * in.dimension(2);
combined_dims[2] = in.dimension(3);
- CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension1And2InTensor3<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- config.virtual_thread_count, in.data(), combined_dims, out.data());
+ RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
}
};
@@ -241,10 +393,7 @@ struct NCHWToNHWC<GPUDevice, T> {
combined_dims[0] = in.dimension(0);
combined_dims[1] = in.dimension(1);
combined_dims[2] = in.dimension(2) * in.dimension(3);
- CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension1And2InTensor3<
- T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
- config.virtual_thread_count, in.data(), combined_dims, out.data());
+ RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
}
};