Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 486
1 file changed, 283 insertions, 203 deletions
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 5390222b3a..2a25459194 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -165,15 +165,18 @@ __global__ void __launch_bounds__(1024, 2)
 // one each in the lower and upper half of a tile.
 // Backprop input direction is the same as forward direction with the filter
 // rotated by 180°.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
 template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
-          bool kKnownEvenHeight>
+          bool kKnownEvenHeight, typename S>
 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
     const DepthwiseArgs args, const T* input, const T* filter, T* output) {
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.x depths.
-  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
-  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+  extern __shared__ __align__(8) unsigned char shared_memory[];
+  static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+  S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
   const int in_height = args.in_rows;
@@ -219,7 +222,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
 
   // Initialize tile, in particular the padding.
   for (int i = thread_idx; i < tile_size; i += block_size) {
-    shared_data[i] = T(0);
+    shared_data[i] = S();
   }
 
   __syncthreads();
@@ -254,14 +257,15 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
 
     if (channel_in_range) {
       const T* const in_ptr = inout_offset + input;
-      T* const tile_ptr = tile_idx + shared_data;
-      tile_ptr[0] = ldg(in_ptr);
+      S* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = static_cast<S>(ldg(in_ptr));
       if (!skip_second) {
-        tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
+        tile_ptr[tile_offset] = static_cast<S>(ldg(tensor_offset + in_ptr));
       }
 
       if (filter_write_offset != 0) {
-        shared_data[filter_write_offset] = ldg(filter_offset + filter);
+        shared_data[filter_write_offset] =
+            static_cast<S>(ldg(filter_offset + filter));
       }
     }
 
@@ -269,17 +273,17 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
     __syncthreads();
 
     if (channel_in_range) {
-      T sum1 = static_cast<T>(0);
-      T sum2 = static_cast<T>(0);
+      S sum1 = S();
+      S sum2 = S();
       int shared_offset = data_idx;
-      const T* filter_ptr = filter_read_offset + shared_data;
+      const S* filter_ptr = filter_read_offset + shared_data;
       UNROLL for (int r = 0; r < filter_height; ++r) {
         UNROLL for (int c = 0; c < filter_width; ++c) {
           if (kDirection == DIRECTION_BACKWARD) {
             filter_ptr -= kBlockDepth;
           }
-          const T filter_value = *filter_ptr;
-          const T* const tile_ptr = shared_offset + shared_data;
+          const S filter_value = *filter_ptr;
+          const S* const tile_ptr = shared_offset + shared_data;
           sum1 += filter_value * tile_ptr[0];
           sum2 += filter_value * tile_ptr[tile_offset];
           shared_offset += kBlockDepth;
@@ -290,9 +294,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
         shared_offset += in_increment;
       }
       T* const out_ptr = inout_offset + output;
-      out_ptr[0] = sum1;
+      out_ptr[0] = static_cast<T>(sum1);
       if (!skip_second) {
-        out_ptr[tensor_offset] = sum2;
+        out_ptr[tensor_offset] = static_cast<T>(sum2);
       }
     }
 
@@ -445,15 +449,18 @@ __global__ void __launch_bounds__(1024, 2)
 // one each in the lower and upper half of a tile.
 // Backprop input direction is the same as forward direction with the filter
 // rotated by 180°.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
 template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
-          bool kKnownEvenHeight>
+          bool kKnownEvenHeight, typename S>
 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
     const DepthwiseArgs args, const T* input, const T* filter, T* output) {
   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
   // Holds block plus halo and filter data for blockDim.z depths.
-  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
-  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+  extern __shared__ __align__(8) unsigned char shared_memory[];
+  static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+  S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
   const int in_height = args.in_rows;
@@ -498,7 +505,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
 
   // Initialize tile, in particular the padding.
   for (int i = thread_idx; i < tile_size; i += block_size) {
-    shared_data[i] = T(0);
+    shared_data[i] = S();
   }
 
   __syncthreads();
@@ -534,34 +541,35 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
 
     if (channel_in_range) {
      const T* const in_ptr = inout_offset + input;
-      T* const tile_ptr = tile_idx + shared_data;
-      tile_ptr[0] = ldg(in_ptr);
+      S* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = static_cast<S>(ldg(in_ptr));
       if (!skip_second) {
-        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+        tile_ptr[tile_offset] = static_cast<S>(ldg(block_pixels + in_ptr));
       }
     }
 
     if (filter_write_offset != 0) {
       const int filter_offset =
           filter_idx + (channel + filter_channel) % in_depth;
-      shared_data[filter_write_offset] = ldg(filter_offset + filter);
+      shared_data[filter_write_offset] =
+          static_cast<S>(ldg(filter_offset + filter));
     }
 
    // Note: the condition to reach this is uniform across the entire block.
    __syncthreads();
 
    if (channel_in_range) {
-      T sum1 = static_cast<T>(0);
-      T sum2 = static_cast<T>(0);
+      S sum1 = S();
+      S sum2 = S();
       int shared_offset = data_idx;
-      const T* filter_ptr = filter_read_offset + shared_data;
+      const S* filter_ptr = filter_read_offset + shared_data;
       UNROLL for (int r = 0; r < filter_height; ++r) {
         UNROLL for (int c = 0; c < filter_width; ++c) {
           if (kDirection == DIRECTION_BACKWARD) {
             filter_ptr -= kBlockDepth;
           }
-          const T filter_value = *filter_ptr;
-          const T* const tile_ptr = shared_offset + shared_data;
+          const S filter_value = *filter_ptr;
+          const S* const tile_ptr = shared_offset + shared_data;
           sum1 += filter_value * tile_ptr[0];
           sum2 += filter_value * tile_ptr[tile_offset];
           ++shared_offset;
@@ -572,9 +580,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
         shared_offset += in_increment;
       }
       T* const out_ptr = inout_offset + output;
-      out_ptr[0] = sum1;
+      out_ptr[0] = static_cast<T>(sum1);
       if (!skip_second) {
-        out_ptr[block_pixels] = sum2;
+        out_ptr[block_pixels] = static_cast<T>(sum2);
       }
     }
 
@@ -585,11 +593,11 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
 template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
-          bool kKnownEvenHeight>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
-                                   const DepthwiseArgs& args, const T* input,
-                                   const T* filter, T* output,
-                                   TensorFormat data_format) {
+          bool kKnownEvenHeight, typename S>
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+                                     const DepthwiseArgs& args, const T* input,
+                                     const T* filter, T* output,
+                                     TensorFormat data_format) {
   const int block_height = (args.in_rows + 1) / 2;
   dim3 block_dim;
   int block_count;
@@ -602,7 +610,7 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
       kernel =
           DepthwiseConv2dGPUKernelNHWCSmall<T, kDirection, kKnownFilterWidth,
                                             kKnownFilterHeight, kBlockDepth,
-                                            kKnownEvenHeight>;
+                                            kKnownEvenHeight, S>;
       break;
     case FORMAT_NCHW:
       block_dim = dim3(args.in_cols, block_height, kBlockDepth);
@@ -611,73 +619,126 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
      kernel =
          DepthwiseConv2dGPUKernelNCHWSmall<T, kDirection, kKnownFilterWidth,
                                            kKnownFilterHeight, kBlockDepth,
-                                           kKnownEvenHeight>;
+                                           kKnownEvenHeight, S>;
      break;
    default:
-      LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
-      return;
+      return errors::InvalidArgument("FORMAT_", ToString(data_format),
+                                     " is not supported");
  }
  const int tile_width = args.in_cols + args.filter_cols - 1;
  const int tile_height = block_height * 2 + args.filter_rows - 1;
  const int tile_pixels = tile_height * tile_width;
  const int filter_pixels = args.filter_rows * args.filter_cols;
  const int shared_memory_size =
-      kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T);
+      kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S);
  const int num_outputs = args.out_rows * args.out_cols * block_count;
+  auto device = ctx->eigen_gpu_device();
  CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
      num_outputs, device, kernel, shared_memory_size,
      block_dim.x * block_dim.y * block_dim.z);
  kernel<<<config.block_count, block_dim, shared_memory_size,
           device.stream()>>>(args, input, filter, output);
+  return Status::OK();
+}
+
+namespace detail {
+template <typename T>
+struct PseudoHalfType {
+  using Type = T;
+};
+template <>
+struct PseudoHalfType<Eigen::half> {
+  using Type = float;
+};
+}  // namespace detail
+
+namespace {
+// Maps to float if T is __half, and to T otherwise.
+template <typename T>
+using PseudoHalfType = typename detail::PseudoHalfType<T>::Type;
+
+// Returns whether the context's GPU supports efficient fp16 math.
+bool HasFastHalfMath(OpKernelContext* ctx) {
+  int major, minor;
+  ctx->op_device_context()
+      ->stream()
+      ->parent()
+      ->GetDeviceDescription()
+      .cuda_compute_capability(&major, &minor);
+  auto cuda_arch = major * 100 + minor * 10;
+  // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow.
+  return cuda_arch >= 530 && cuda_arch != 610;
+}
+}  // namespace
+
+template <typename T, DepthwiseConv2dDirection kDirection,
+          int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
+          bool kKnownEvenHeight>
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+                                     const DepthwiseArgs& args, const T* input,
+                                     const T* filter, T* output,
+                                     TensorFormat data_format) {
+#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530
+  if (HasFastHalfMath(ctx)) {
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, kBlockDepth,
+                                         kKnownEvenHeight, T>(
+        ctx, args, input, filter, output, data_format);
+  }
+#endif
+  return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                       kKnownFilterHeight, kBlockDepth,
+                                       kKnownEvenHeight, PseudoHalfType<T>>(
+      ctx, args, input, filter, output, data_format);
 }
 
 template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
-                                   const DepthwiseArgs& args, const T* input,
-                                   const T* filter, T* output,
-                                   TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+                                     const DepthwiseArgs& args, const T* input,
+                                     const T* filter, T* output,
+                                     TensorFormat data_format) {
   if (args.in_rows & 1) {
-    LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
-                                  kKnownFilterHeight, kBlockDepth, false>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, kBlockDepth,
+                                         false>(ctx, args, input, filter,
+                                                output, data_format);
   } else {
-    LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
-                                  kKnownFilterHeight, kBlockDepth, true>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, kBlockDepth, true>(
+        ctx, args, input, filter, output, data_format);
   }
 }
 
 template <typename T, DepthwiseConv2dDirection kDirection,
           int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
-                                   const DepthwiseArgs& args, const T* input,
-                                   const T* filter, T* output,
-                                   TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+                                     const DepthwiseArgs& args, const T* input,
+                                     const T* filter, T* output,
+                                     TensorFormat data_format) {
   // Maximize (power of two) kBlockDepth while keeping a block within 1024
   // threads (2 pixels per thread).
   const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols;
   if (block_pixels > 256) {
-    LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
-                                  kKnownFilterHeight, 2>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, 2>(
+        ctx, args, input, filter, output, data_format);
   } else if (block_pixels > 128) {
-    LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
-                                  kKnownFilterHeight, 4>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, 4>(
+        ctx, args, input, filter, output, data_format);
   } else {
-    LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
-                                  kKnownFilterHeight, 8>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+                                         kKnownFilterHeight, 8>(
+        ctx, args, input, filter, output, data_format);
   }
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
-                              const DepthwiseArgs& args, const T* input,
-                              const T* filter, T* output,
-                              TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
+                                const T* input, const T* filter, T* output,
+                                TensorFormat data_format) {
   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
   switch (data_format) {
     case FORMAT_NHWC:
@@ -691,11 +752,12 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
                                       kKnownDepthMultiplier>;
       break;
     default:
-      LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
-      return;
+      return errors::InvalidArgument("FORMAT_", ToString(data_format),
+                                     " is not supported");
   }
   const int num_outputs =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
+  auto device = ctx->eigen_gpu_device();
   CudaLaunchConfig config =
       GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0);
   // The compile-time constant version runs faster with a single block.
@@ -706,26 +768,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
   kernel<<<std::min(max_block_count, config.block_count),
            config.thread_per_block, 0, device.stream()>>>(args, input, filter,
                                                           output, num_outputs);
+  return Status::OK();
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
-                              const DepthwiseArgs& args, const T* input,
-                              const T* filter, T* output,
-                              TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
+                                const T* input, const T* filter, T* output,
+                                TensorFormat data_format) {
   if (args.depth_multiplier == 1) {
     if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
-      LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_FORWARD, kKnownFilterWidth,
-                                    kKnownFilterHeight>(
-          device, args, input, filter, output, data_format);
-      return;
+      return LaunchDepthwiseConv2dGPUSmall<
+          T, DIRECTION_FORWARD, kKnownFilterWidth, kKnownFilterHeight>(
+          ctx, args, input, filter, output, data_format);
     }
 
-    LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, 1>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight,
+                                    1>(ctx, args, input, filter, output,
+                                       data_format);
   } else {
-    LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, -1>(
-        device, args, input, filter, output, data_format);
+    return LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight,
+                                    -1>(ctx, args, input, filter, output,
+                                        data_format);
   }
 }
 
@@ -736,18 +799,13 @@ void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx,
                                                      const T* input,
                                                      const T* filter, T* output,
                                                      TensorFormat data_format) {
-  const GpuDevice& device = ctx->eigen_device<GpuDevice>();
   if (args.filter_rows == 3 && args.filter_cols == 3) {
-    LaunchDepthwiseConv2dGPU<T, 3, 3>(device, args, input, filter, output,
-                                      data_format);
+    OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU<T, 3, 3>(
+                            ctx, args, input, filter, output, data_format));
   } else {
-    LaunchDepthwiseConv2dGPU<T, -1, -1>(device, args, input, filter, output,
-                                        data_format);
+    OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU<T, -1, -1>(
+                            ctx, args, input, filter, output, data_format));
   }
-  auto stream = ctx->op_device_context()->stream();
-  OP_REQUIRES(ctx, stream->ok(),
-              errors::Internal(
-                  "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
 }
 
 template struct LaunchDepthwiseConvOp<GpuDevice, Eigen::half>;
@@ -904,11 +962,11 @@ __global__ void __launch_bounds__(640, 2)
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
-                                           const DepthwiseArgs& args,
-                                           const T* out_backprop,
-                                           const T* filter, T* in_backprop,
-                                           TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
+                                             const DepthwiseArgs& args,
+                                             const T* out_backprop,
+                                             const T* filter, T* in_backprop,
+                                             TensorFormat data_format) {
   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
   switch (data_format) {
     case FORMAT_NHWC:
@@ -920,38 +978,39 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
       break;
     default:
-      LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
-      return;
+      return errors::InvalidArgument("FORMAT_", ToString(data_format),
+                                     " is not supported");
   }
   const int num_in_backprop =
       args.batch * args.in_rows * args.in_cols * args.in_depth;
+  auto device = ctx->eigen_gpu_device();
   CudaLaunchConfig config =
       GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0);
   kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
       args, out_backprop, filter, in_backprop, num_in_backprop);
+  return Status::OK();
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
-                                           const DepthwiseArgs& args,
-                                           const T* out_backprop,
-                                           const T* filter, T* in_backprop,
-                                           TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
+                                             const DepthwiseArgs& args,
+                                             const T* out_backprop,
+                                             const T* filter, T* in_backprop,
+                                             TensorFormat data_format) {
   if (args.depth_multiplier == 1) {
     if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
-      LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_BACKWARD, kKnownFilterWidth,
-                                    kKnownFilterHeight>(
-          device, args, out_backprop, filter, in_backprop, data_format);
-      return;
+      return LaunchDepthwiseConv2dGPUSmall<
+          T, DIRECTION_BACKWARD, kKnownFilterWidth, kKnownFilterHeight>(
+          ctx, args, out_backprop, filter, in_backprop, data_format);
     }
 
-    LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
-                                          kKnownFilterHeight, 1>(
-        device, args, out_backprop, filter, in_backprop, data_format);
+    return LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+                                                 kKnownFilterHeight, 1>(
+        ctx, args, out_backprop, filter, in_backprop, data_format);
   } else {
-    LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
-                                          kKnownFilterHeight, -1>(
-        device, args, out_backprop, filter, in_backprop, data_format);
+    return LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+                                                 kKnownFilterHeight, -1>(
+        ctx, args, out_backprop, filter, in_backprop, data_format);
   }
 }
 
@@ -960,19 +1019,15 @@ template <typename T>
 void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()(
     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
     const T* filter, T* in_backprop, TensorFormat data_format) {
-  const GpuDevice& device = ctx->eigen_device<GpuDevice>();
   if (args.filter_rows == 3 && args.filter_cols == 3) {
-    LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
-        device, args, out_backprop, filter, in_backprop, data_format);
+    OP_REQUIRES_OK(
+        ctx, LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
+                 ctx, args, out_backprop, filter, in_backprop, data_format));
   } else {
-    LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
-        device, args, out_backprop, filter, in_backprop, data_format);
+    OP_REQUIRES_OK(
+        ctx, LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
+                 ctx, args, out_backprop, filter, in_backprop, data_format));
   }
-  auto stream = ctx->op_device_context()->stream();
-  OP_REQUIRES(ctx, stream->ok(),
-              errors::Internal("Launch of gpu kernel for "
-                               "DepthwiseConv2dBackpropInp"
-                               "utGPULaunch failed"));
 }
 
 template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, Eigen::half>;
@@ -1111,15 +1166,18 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
 // up in global memory using atomics.
 // Requirements: threads per block must be multiple of 32 and <= launch_bounds,
 // kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
-          int kBlockDepth, int kAccumPixels>
+          int kBlockDepth, int kAccumPixels, typename S>
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
     const DepthwiseArgs args, const T* output, const T* input, T* filter) {
   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
   // Holds block plus halo and filter data for blockDim.x depths.
-  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
-  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+  extern __shared__ __align__(8) unsigned char shared_memory[];
+  static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+  S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
   const int in_height = args.in_rows;
@@ -1169,7 +1227,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
 
   // Initialize tile, in particular the padding and accumulator.
   for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
-    shared_data[i] = T(0);
+    shared_data[i] = S();
   }
 
   __syncthreads();
@@ -1203,10 +1261,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
 
     if (channel_in_range) {
       const T* const in_ptr = inout_offset + input;
-      T* const tile_ptr = tile_idx + shared_data;
-      tile_ptr[0] = ldg(in_ptr);
+      S* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = static_cast<S>(ldg(in_ptr));
       if (!skip_second) {
-        tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
+        tile_ptr[tile_offset] = static_cast<S>(ldg(tensor_offset + in_ptr));
       }
     }
 
@@ -1216,14 +1274,15 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
 
     if (channel_in_range) {
       const T* const out_ptr = inout_offset + output;
-      const T out1 = ldg(out_ptr);
-      const T out2 = skip_second ? T(0) : ldg(tensor_offset + out_ptr);
+      const S out1 = static_cast<S>(ldg(out_ptr));
+      const S out2 =
+          skip_second ? S() : static_cast<S>(ldg(tensor_offset + out_ptr));
       int shared_offset = data_idx;
-      T* accum_ptr = accum_offset + shared_data;
+      S* accum_ptr = accum_offset + shared_data;
       UNROLL for (int r = 0; r < filter_height; ++r) {
         UNROLL for (int c = 0; c < filter_width; ++c) {
-          const T* const tile_ptr = shared_offset + shared_data;
-          T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+          const S* const tile_ptr = shared_offset + shared_data;
+          S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
           // Warp-accumulate pixels of the same depth and write to accumulator.
           for (int delta = 16; delta >= kBlockDepth; delta /= 2) {
             val += CudaShuffleXorSync(active_threads, val, delta);
@@ -1241,18 +1300,18 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
     // Note: the condition to reach this is uniform across the entire block.
     __syncthreads();
 
-    const T* const accum_data = tile_size + shared_data;
+    const S* const accum_data = tile_size + shared_data;
     for (int i = thread_idx; i < accum_size; i += block_size) {
       const int filter_idx = i / kAccumPixels;
       const int filter_pix = filter_idx / kBlockDepth;
       const int filter_channel = filter_idx % kBlockDepth + start_channel;
       const int filter_offset = filter_pix * in_depth + filter_channel;
       if (filter_channel < in_depth) {
-        T val = accum_data[i];
+        S val = accum_data[i];
         // Warp-accumulate the pixels of the same depth from the accumulator.
         val = WarpSumReduce<kAccumPixels>(val);
         if (!(thread_idx & kAccumPixels - 1)) {
-          CudaAtomicAdd(filter_offset + filter, val);
+          CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
         }
       }
     }
@@ -1382,14 +1441,15 @@ __global__ void __launch_bounds__(640, 2)
 // Requirements: threads per block must be multiple of 32 and <= launch_bounds,
 // kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
-          int kBlockDepth, int kAccumPixels>
+          int kBlockDepth, int kAccumPixels, typename S>
 __global__
 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
     const DepthwiseArgs args, const T* output, const T* input, T* filter) {
   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
   // Holds block plus halo and filter data for blockDim.z depths.
-  extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
-  T* const shared_data = reinterpret_cast<T*>(shared_memory);
+  extern __shared__ __align__(8) unsigned char shared_memory[];
+  static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+  S* const shared_data = reinterpret_cast<S*>(shared_memory);
 
   const int num_batches = args.batch;
   const int in_height = args.in_rows;
@@ -1438,7 +1498,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 
   // Initialize tile, in particular the padding and accumulator.
   for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
-    shared_data[i] = T(0);
+    shared_data[i] = S();
   }
 
   __syncthreads();
@@ -1468,10 +1528,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 
     if (channel_in_range) {
       const T* const in_ptr = inout_offset + input;
-      T* const tile_ptr = tile_idx + shared_data;
-      tile_ptr[0] = ldg(in_ptr);
+      S* const tile_ptr = tile_idx + shared_data;
+      tile_ptr[0] = static_cast<S>(ldg(in_ptr));
      if (!skip_second) {
-        tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+        tile_ptr[tile_offset] = static_cast<S>(ldg(block_pixels + in_ptr));
      }
     }
 
@@ -1481,14 +1541,15 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 
    if (channel_in_range) {
      const T* const out_ptr = inout_offset + output;
-      const T out1 = ldg(out_ptr);
-      const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr);
+      const S out1 = static_cast<S>(ldg(out_ptr));
+      const S out2 =
+          skip_second ? S() : static_cast<S>(ldg(block_pixels + out_ptr));
      int shared_offset = data_idx;
-      T* accum_ptr = accum_offset + shared_data;
+      S* accum_ptr = accum_offset + shared_data;
      UNROLL for (int r = 0; r < filter_height; ++r) {
        UNROLL for (int c = 0; c < filter_width; ++c) {
-          const T* const tile_ptr = shared_offset + shared_data;
-          T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+          const S* const tile_ptr = shared_offset + shared_data;
+          S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
          // Warp-accumulate pixels of the same depth and write to accumulator.
          for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) {
            val += CudaShuffleXorSync(active_threads, val, delta);
@@ -1506,7 +1567,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 
    // Note: the condition to reach this is uniform across the entire block.
    __syncthreads();
 
-    const T* const accum_data = tile_size + shared_data;
+    const S* const accum_data = tile_size + shared_data;
    for (int i = thread_idx; i < accum_size; i += block_size) {
      const int filter_idx = i / kAccumPixels;
      const int filter_pix = filter_idx / kBlockDepth;
@@ -1514,11 +1575,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
          (channel + filter_idx % kBlockDepth) % in_depth;
      const int filter_offset = filter_pix * in_depth + filter_channel;
      if (filter_channel < in_depth) {
-        T val = accum_data[i];
+        S val = accum_data[i];
        // Warp-accumulate pixels of the same depth from the accumulator.
        val = WarpSumReduce<kAccumPixels>(val);
        if (!(thread_idx & kAccumPixels - 1)) {
-          CudaAtomicAdd(filter_offset + filter, val);
+          CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
        }
      }
    }
@@ -1526,19 +1587,20 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
-          int kBlockDepth, int kAccumPixels>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
-    const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
+          int kBlockDepth, int kAccumPixels, typename S>
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
     const T* out_backprop, const T* input, T* filter_backprop,
     TensorFormat data_format) {
+  auto device = ctx->eigen_gpu_device();
   const int tile_width = args.in_cols + args.filter_cols - 1;
   const int tile_height = block_height * 2 + args.filter_rows - 1;
   const int tile_pixels = tile_height * tile_width;
   const int filter_pixels = args.filter_rows * args.filter_cols;
   const int shared_memory_size =
-      kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T);
+      kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(S);
   if (shared_memory_size > device.sharedMemPerBlock()) {
-    return false;
+    return errors::FailedPrecondition("Not enough shared memory");
   }
 
   dim3 block_dim;
@@ -1550,18 +1612,20 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
       block_count =
          args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
       kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
-          T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
+          T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+          S>;
       break;
     case FORMAT_NCHW:
       block_dim = dim3(args.in_cols, block_height, kBlockDepth);
       block_count =
          DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
       kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
-          T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
+          T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+          S>;
       break;
     default:
-      LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
-      return false;
+      return errors::InvalidArgument("FORMAT_", ToString(data_format),
+                                     " is not supported");
   }
   const int num_out_backprop = args.out_rows * args.out_cols * block_count;
   CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
@@ -1569,13 +1633,33 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
       block_dim.x * block_dim.y * block_dim.z);
   kernel<<<config.block_count, block_dim, shared_memory_size,
            device.stream()>>>(args, out_backprop, input, filter_backprop);
-  return true;
+  return Status::OK();
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+          int kBlockDepth, int kAccumPixels>
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
+    const T* out_backprop, const T* input, T* filter_backprop,
+    TensorFormat data_format) {
+#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530
+  if (HasFastHalfMath(ctx)) {
+    return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+        T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, T>(
+        ctx, args, block_height, out_backprop, input, filter_backprop,
+        data_format);
+  }
+#endif
+  return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+      T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+      PseudoHalfType<T>>(ctx, args, block_height, out_backprop, input,
+                         filter_backprop, data_format);
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kBlockDepth>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
-    const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
     const T* out_backprop, const T* input, T* filter_backprop,
     TensorFormat data_format) {
   // Minimize (power of two) kAccumPixels, while satisfying
@@ -1584,24 +1668,24 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
   if (block_pixels > 512) {
     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>(
-        device, args, block_height, out_backprop, input, filter_backprop,
+        ctx, args, block_height, out_backprop, input, filter_backprop,
         data_format);
   } else if (block_pixels > 256) {
     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>(
-        device, args, block_height, out_backprop, input, filter_backprop,
+        ctx, args, block_height, out_backprop, input, filter_backprop,
         data_format);
   } else {
     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>(
-        device, args, block_height, out_backprop, input, filter_backprop,
+        ctx, args, block_height, out_backprop, input, filter_backprop,
         data_format);
   }
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
-    const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop,
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
     const T* input, T* filter_backprop, TensorFormat data_format) {
   // Maximize (power of two) kBlockDepth while keeping a block within 1024
   // threads (2 pixels per thread).
@@ -1621,37 +1705,35 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
   }
 
   if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) {
-    return false;
+    return errors::FailedPrecondition("Cannot launch this configuration");
   }
 
   switch (block_depth) {
     case 8:
       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
           T, kKnownFilterWidth, kKnownFilterHeight, 8>(
-          device, args, block_height, out_backprop, input, filter_backprop,
+          ctx, args, block_height, out_backprop, input, filter_backprop,
          data_format);
     case 4:
       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
          T, kKnownFilterWidth, kKnownFilterHeight, 4>(
-          device, args, block_height, out_backprop, input, filter_backprop,
+          ctx, args, block_height, out_backprop, input, filter_backprop,
          data_format);
     case 2:
       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
          T, kKnownFilterWidth, kKnownFilterHeight, 2>(
-          device, args, block_height, out_backprop, input, filter_backprop,
+          ctx, args, block_height, out_backprop, input, filter_backprop,
          data_format);
     default:
-      return false;
+      return errors::InvalidArgument("Unexpected block depth");
   }
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
-                                            const DepthwiseArgs& args,
-                                            const T* out_backprop,
-                                            const T* input, T* filter_backprop,
-                                            TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropFilterGPU(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+    const T* input, T* filter_backprop, TensorFormat data_format) {
   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
   switch (data_format) {
     case FORMAT_NHWC:
@@ -1663,37 +1745,38 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
       break;
     default:
-      LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
-      return;
+      return errors::InvalidArgument("FORMAT_", ToString(data_format),
+                                     " is not supported");
   }
   const int num_out_backprop =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
+  auto device = ctx->eigen_gpu_device();
   CudaLaunchConfig config =
       GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0);
   kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
      args, out_backprop, input, filter_backprop, num_out_backprop);
+  return Status::OK();
 }
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
-                                            const DepthwiseArgs& args,
-                                            const T* out_backprop,
-                                            const T* input, T* filter_backprop,
-                                            TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropFilterGPU(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+    const T* input, T* filter_backprop, TensorFormat data_format) {
   if (args.depth_multiplier == 1) {
     if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
                                                        kKnownFilterHeight>(
-            device, args, out_backprop, input, filter_backprop, data_format)) {
-      return;
+            ctx, args, out_backprop, input, filter_backprop, data_format)
+            .ok()) {
+      return Status::OK();
     }
 
-    LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
-                                           kKnownFilterHeight, 1>(
-        device, args, out_backprop, input, filter_backprop, data_format);
+    return LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+                                                  kKnownFilterHeight, 1>(
+        ctx, args, out_backprop, input, filter_backprop, data_format);
   } else {
-    LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
-                                           kKnownFilterHeight, -1>(
-        device, args, out_backprop, input, filter_backprop, data_format);
+    return LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+                                                  kKnownFilterHeight, -1>(
+        ctx, args, out_backprop, input, filter_backprop, data_format);
   }
 }
 
@@ -1702,7 +1785,6 @@ template <typename T>
 void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
     const T* input, T* filter_backprop, TensorFormat data_format) {
-  const GpuDevice& device = ctx->eigen_device<GpuDevice>();
   auto stream = ctx->op_device_context()->stream();
 
   // Initialize the results to 0.
@@ -1712,16 +1794,14 @@ void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
   stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
 
   if (args.filter_rows == 3 && args.filter_cols == 3) {
-    LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
-        device, args, out_backprop, input, filter_backprop, data_format);
+    OP_REQUIRES_OK(
+        ctx, LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
+                 ctx, args, out_backprop, input, filter_backprop, data_format));
   } else {
-    LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
-        device, args, out_backprop, input, filter_backprop, data_format);
+    OP_REQUIRES_OK(
+        ctx, LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
+                 ctx, args, out_backprop, input, filter_backprop, data_format));
   }
-  OP_REQUIRES(ctx, stream->ok(),
-              errors::Internal("Launch of gpu kernel for "
-                               "DepthwiseConv2dBackpropFil"
-                               "terGPULaunch failed"));
 }
 
 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, Eigen::half>;
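
The core idea of the patch — keep tensors in half precision but do the kernel math in float on GPUs without fast fp16 arithmetic — can be illustrated outside TensorFlow with a small type-trait sketch. This is an illustrative example, not code from the patch; the names AccumType/AccumT and the half16 stand-in for Eigen::half are hypothetical:

  // Sketch of the pseudo-half pattern: choose a wider accumulator type per
  // tensor type. Hypothetical names; the patch's equivalent is PseudoHalfType.
  #include <type_traits>

  struct half16 { unsigned short raw; };  // stand-in for Eigen::half

  template <typename T>
  struct AccumType { using type = T; };            // default: accumulate in T
  template <>
  struct AccumType<half16> { using type = float; };  // half: accumulate in float

  template <typename T>
  using AccumT = typename AccumType<T>::type;

  static_assert(std::is_same<AccumT<float>, float>::value, "float stays float");
  static_assert(std::is_same<AccumT<half16>, float>::value, "half widens to float");

A launcher templated on both the tensor type T and the math type S, as in this change, can then instantiate the kernel twice — once with S = T for devices with efficient fp16 math and once with S = AccumT<T> as the pseudo-half fallback — and pick between the two at runtime from the device's compute capability.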