diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-25 08:48:47 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-25 08:48:47 -0800 |
commit | 854f49bd43588c062b046384f239f64a3d819702 (patch) | |
tree | c2373bf71ef65ae4c116ea703947141281c1eace /tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | |
parent | 9c3043ff3bf31a6a81810b4ce9e87ef936f1f529 (diff) |
TensorFlow: Upstream changes to git
Changes:
- Updates to docs
- Several changes for Python 3 compatibility
- Added license headers
Base CL: 108710566
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu_3.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 177 |
1 files changed, 163 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index a40b3caefd..2d7e149c7b 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // A simple array that contains data that can be passed between CPU and GPU. -template <typename T, int IndexCount> +template <typename T, int IndexCount, T DefaultValue> struct Array { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const { return data[index]; @@ -38,16 +38,57 @@ struct Array { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) { return data[index]; } - int data[IndexCount]; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() { + for (int i = 0; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) { + data[0] = a0; + for (int i = 1; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) { + data[0] = a0; + data[1] = a1; + for (int i = 2; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) { + data[0] = a0; + data[1] = a1; + data[2] = a2; + for (int i = 3; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + T data[IndexCount]; }; // A dimension type with compile-time known size. template <int IndexCount> -struct Dimension : Array<int, IndexCount> {}; +struct Dimension : Array<int, IndexCount, 1> { + typedef Array<int, IndexCount, 1> Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1) + : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; // An index type with compile-time known size. template <int IndexCount> -struct Index : Array<int, IndexCount> {}; +struct Index : Array<int, IndexCount, 0> { + typedef Array<int, IndexCount, 0> Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; // A helper function that converts a tensor index into a flat array index. template <int IndexCount> @@ -94,7 +135,7 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input, int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - output[output_index] = __ldg(input + input_index); + output[output_index] = ldg(input + input_index); } } @@ -119,7 +160,88 @@ __global__ void SwapDimension1And2InTensor3(int nthreads, const T* input, int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - output[output_index] = __ldg(input + input_index); + output[output_index] = ldg(input + input_index); + } +} + +// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor, +// where dimensions are zero-based: output[i][j][k] = input[i][k][j]. +// TileSize could be arbitrary. But for best performance, it is better to be +// the same as number of threads in a warp, which is 32 for almost all GPU +// architectures. +template <typename T, int TileSize> +__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, + Dimension<3> input_dims, + T* output) { + // One extra line in the inner dimension to avoid share memory bank conflict. + __shared__ T shared_memory_tile[TileSize][TileSize + 1]; + + int x = threadIdx.x; + if (x >= TileSize) { + return; + } + + Dimension<3> output_dims = { + input_dims[0], input_dims[2], input_dims[1], + }; + + Dimension<3> input_dims_in_tiles = { + input_dims[0], (input_dims[1] + TileSize - 1) / TileSize, + (input_dims[2] + TileSize - 1) / TileSize, + }; + + Index<3> input_tile_index = + FlatToTensorIndex(blockIdx.x, input_dims_in_tiles); + + Index<3> input_tile_origin = { + input_tile_index[0], input_tile_index[1] * TileSize, + input_tile_index[2] * TileSize, + }; + + int input_origin_flat_index = + TensorIndexToFlat(input_tile_origin, input_dims); + + int tile_width = TileSize; + // Only the last row or column may not have the full size. + if (input_tile_index[2] == input_dims_in_tiles[2] - 1) { + tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSize; + } + int tile_height = TileSize; + if (input_tile_index[1] == input_dims_in_tiles[1] - 1) { + tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize; + } + + // Load the data from input memory to the shared memory tile. + if (x < tile_width) { + int input_flat_index = input_origin_flat_index + x; + for (int y = 0; y < tile_height; y++) { + shared_memory_tile[y][x] = input[input_flat_index]; + input_flat_index += input_dims[2]; + } + } + + __syncthreads(); + + Index<3> output_tile_index = { + input_tile_index[0], input_tile_index[2], input_tile_index[1], + }; + + Index<3> output_tile_origin = { + output_tile_index[0], output_tile_index[1] * TileSize, + output_tile_index[2] * TileSize, + }; + + int output_origin_flat_index = + TensorIndexToFlat(output_tile_origin, output_dims); + + int output_flat_index = output_origin_flat_index + x; + + // Load the data from the shared memory tile to the output memory. + if (x < tile_height) { + for (int y = 0; y < tile_width; y++) { + output[output_flat_index] = shared_memory_tile[x][y]; + output_flat_index += output_dims[2]; + } } } @@ -212,6 +334,39 @@ struct PadInput<GPUDevice, T, int> { } }; +// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a +// 3D tensor. It looks at the shape of the incoming data, and decides the best +// strategy to launch. +template <typename T> +void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, + const Dimension<3>& input_dims, T* output) { + // If both dimensions are not trivial, use tiles for the actual swapping. + // Otherwise, the trivial swapping relying on the ldg cache is more efficient. + static const int kMinDimensionToUseTiles = 16; + bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles && + input_dims[2] >= kMinDimensionToUseTiles); + if (use_tiles) { + // The tile-size can be chosen to be arbitrary number. But it is better to + // be the same as number of threads in a warp, which is 32. + static const int TileSize = 32; + Dimension<3> input_dims_in_tiles = { + input_dims[0], (input_dims[1] + TileSize - 1) / TileSize, + (input_dims[2] + TileSize - 1) / TileSize, + }; + int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] * + input_dims_in_tiles[2]; + SwapDimension1And2InTensor3UsingTiles< + T, TileSize><<<total_tiles_count, TileSize, 0, d.stream()>>>( + input, input_dims, output); + } else { + int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; + CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); + SwapDimension1And2InTensor3< + T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( + config.virtual_thread_count, input, input_dims, output); + } +} + // A GPU helper functor that converts NHWC TensorFlow data format to // NCHW format that is accepted by Cudnn. template <typename T> @@ -223,10 +378,7 @@ struct NHWCToNCHW<GPUDevice, T> { combined_dims[0] = in.dimension(0); combined_dims[1] = in.dimension(1) * in.dimension(2); combined_dims[2] = in.dimension(3); - CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension1And2InTensor3< - T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, in.data(), combined_dims, out.data()); + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; @@ -241,10 +393,7 @@ struct NCHWToNHWC<GPUDevice, T> { combined_dims[0] = in.dimension(0); combined_dims[1] = in.dimension(1); combined_dims[2] = in.dimension(2) * in.dimension(3); - CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension1And2InTensor3< - T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, in.data(), combined_dims, out.data()); + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; |