Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc')
 tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 486 +++++++++++----------
 1 file changed, 283 insertions(+), 203 deletions(-)
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 5390222b3a..2a25459194 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -165,15 +165,18 @@ __global__ void __launch_bounds__(1024, 2)
// one each in the lower and upper half of a tile.
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
- bool kKnownEvenHeight>
+ bool kKnownEvenHeight, typename S>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.x depths.
- extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
- T* const shared_data = reinterpret_cast<T*>(shared_memory);
+ extern __shared__ __align__(8) unsigned char shared_memory[];
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+ S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
const int in_height = args.in_rows;
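The shared-memory pattern introduced here can be seen in isolation in the following standalone sketch: the dynamic buffer is declared with a fixed 8-byte alignment, a static_assert confirms that this covers the math type S, and the buffer is reinterpreted to S*. The kernel name and sizes are illustrative, not part of the patch.

// Standalone sketch of the aligned dynamic shared-memory pattern above.
#include <cstdio>
#include <cuda_runtime.h>

template <typename S>
__global__ void ZeroInitSharedSketch(S* out, int n) {
  extern __shared__ __align__(8) unsigned char shared_memory[];
  static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
  S* const shared_data = reinterpret_cast<S*>(shared_memory);
  // Initialize the tile, as the depthwise kernels do for their padding.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    shared_data[i] = S();
  }
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[i] = shared_data[i];
  }
}

int main() {
  const int n = 256;
  float* out = nullptr;
  cudaMallocManaged(&out, n * sizeof(float));
  // The dynamic shared-memory size is the third launch parameter, just like
  // shared_memory_size in the launchers further down.
  ZeroInitSharedSketch<float><<<1, 128, n * sizeof(float)>>>(out, n);
  cudaDeviceSynchronize();
  std::printf("out[0] = %f\n", out[0]);  // expect 0
  cudaFree(out);
  return 0;
}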
@@ -219,7 +222,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
// Initialize tile, in particular the padding.
for (int i = thread_idx; i < tile_size; i += block_size) {
- shared_data[i] = T(0);
+ shared_data[i] = S();
}
__syncthreads();
@@ -254,14 +257,15 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
if (channel_in_range) {
const T* const in_ptr = inout_offset + input;
- T* const tile_ptr = tile_idx + shared_data;
- tile_ptr[0] = ldg(in_ptr);
+ S* const tile_ptr = tile_idx + shared_data;
+ tile_ptr[0] = static_cast<S>(ldg(in_ptr));
if (!skip_second) {
- tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
+ tile_ptr[tile_offset] = static_cast<S>(ldg(tensor_offset + in_ptr));
}
if (filter_write_offset != 0) {
- shared_data[filter_write_offset] = ldg(filter_offset + filter);
+ shared_data[filter_write_offset] =
+ static_cast<S>(ldg(filter_offset + filter));
}
}
@@ -269,17 +273,17 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
__syncthreads();
if (channel_in_range) {
- T sum1 = static_cast<T>(0);
- T sum2 = static_cast<T>(0);
+ S sum1 = S();
+ S sum2 = S();
int shared_offset = data_idx;
- const T* filter_ptr = filter_read_offset + shared_data;
+ const S* filter_ptr = filter_read_offset + shared_data;
UNROLL for (int r = 0; r < filter_height; ++r) {
UNROLL for (int c = 0; c < filter_width; ++c) {
if (kDirection == DIRECTION_BACKWARD) {
filter_ptr -= kBlockDepth;
}
- const T filter_value = *filter_ptr;
- const T* const tile_ptr = shared_offset + shared_data;
+ const S filter_value = *filter_ptr;
+ const S* const tile_ptr = shared_offset + shared_data;
sum1 += filter_value * tile_ptr[0];
sum2 += filter_value * tile_ptr[tile_offset];
shared_offset += kBlockDepth;
@@ -290,9 +294,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
shared_offset += in_increment;
}
T* const out_ptr = inout_offset + output;
- out_ptr[0] = sum1;
+ out_ptr[0] = static_cast<T>(sum1);
if (!skip_second) {
- out_ptr[tensor_offset] = sum2;
+ out_ptr[tensor_offset] = static_cast<T>(sum2);
}
}
@@ -445,15 +449,18 @@ __global__ void __launch_bounds__(1024, 2)
// one each in the lower and upper half of a tile.
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
- bool kKnownEvenHeight>
+ bool kKnownEvenHeight, typename S>
__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* input, const T* filter, T* output) {
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.z depths.
- extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
- T* const shared_data = reinterpret_cast<T*>(shared_memory);
+ extern __shared__ __align__(8) unsigned char shared_memory[];
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+ S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
const int in_height = args.in_rows;
@@ -498,7 +505,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
// Initialize tile, in particular the padding.
for (int i = thread_idx; i < tile_size; i += block_size) {
- shared_data[i] = T(0);
+ shared_data[i] = S();
}
__syncthreads();
@@ -534,34 +541,35 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
if (channel_in_range) {
const T* const in_ptr = inout_offset + input;
- T* const tile_ptr = tile_idx + shared_data;
- tile_ptr[0] = ldg(in_ptr);
+ S* const tile_ptr = tile_idx + shared_data;
+ tile_ptr[0] = static_cast<S>(ldg(in_ptr));
if (!skip_second) {
- tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+ tile_ptr[tile_offset] = static_cast<S>(ldg(block_pixels + in_ptr));
}
}
if (filter_write_offset != 0) {
const int filter_offset =
filter_idx + (channel + filter_channel) % in_depth;
- shared_data[filter_write_offset] = ldg(filter_offset + filter);
+ shared_data[filter_write_offset] =
+ static_cast<S>(ldg(filter_offset + filter));
}
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
if (channel_in_range) {
- T sum1 = static_cast<T>(0);
- T sum2 = static_cast<T>(0);
+ S sum1 = S();
+ S sum2 = S();
int shared_offset = data_idx;
- const T* filter_ptr = filter_read_offset + shared_data;
+ const S* filter_ptr = filter_read_offset + shared_data;
UNROLL for (int r = 0; r < filter_height; ++r) {
UNROLL for (int c = 0; c < filter_width; ++c) {
if (kDirection == DIRECTION_BACKWARD) {
filter_ptr -= kBlockDepth;
}
- const T filter_value = *filter_ptr;
- const T* const tile_ptr = shared_offset + shared_data;
+ const S filter_value = *filter_ptr;
+ const S* const tile_ptr = shared_offset + shared_data;
sum1 += filter_value * tile_ptr[0];
sum2 += filter_value * tile_ptr[tile_offset];
++shared_offset;
@@ -572,9 +580,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
shared_offset += in_increment;
}
T* const out_ptr = inout_offset + output;
- out_ptr[0] = sum1;
+ out_ptr[0] = static_cast<T>(sum1);
if (!skip_second) {
- out_ptr[block_pixels] = sum2;
+ out_ptr[block_pixels] = static_cast<T>(sum2);
}
}
@@ -585,11 +593,11 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
- bool kKnownEvenHeight>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
- const DepthwiseArgs& args, const T* input,
- const T* filter, T* output,
- TensorFormat data_format) {
+ bool kKnownEvenHeight, typename S>
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+ const DepthwiseArgs& args, const T* input,
+ const T* filter, T* output,
+ TensorFormat data_format) {
const int block_height = (args.in_rows + 1) / 2;
dim3 block_dim;
int block_count;
@@ -602,7 +610,7 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
kernel =
DepthwiseConv2dGPUKernelNHWCSmall<T, kDirection, kKnownFilterWidth,
kKnownFilterHeight, kBlockDepth,
- kKnownEvenHeight>;
+ kKnownEvenHeight, S>;
break;
case FORMAT_NCHW:
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
@@ -611,73 +619,126 @@ void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
kernel =
DepthwiseConv2dGPUKernelNCHWSmall<T, kDirection, kKnownFilterWidth,
kKnownFilterHeight, kBlockDepth,
- kKnownEvenHeight>;
+ kKnownEvenHeight, S>;
break;
default:
- LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
- return;
+ return errors::InvalidArgument("FORMAT_", ToString(data_format),
+ " is not supported");
}
const int tile_width = args.in_cols + args.filter_cols - 1;
const int tile_height = block_height * 2 + args.filter_rows - 1;
const int tile_pixels = tile_height * tile_width;
const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
- kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T);
+ kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S);
const int num_outputs = args.out_rows * args.out_cols * block_count;
+ auto device = ctx->eigen_gpu_device();
CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
num_outputs, device, kernel, shared_memory_size,
block_dim.x * block_dim.y * block_dim.z);
kernel<<<config.block_count, block_dim, shared_memory_size,
device.stream()>>>(args, input, filter, output);
+ return Status::OK();
+}
+
+namespace detail {
+template <typename T>
+struct PseudoHalfType {
+ using Type = T;
+};
+template <>
+struct PseudoHalfType<Eigen::half> {
+ using Type = float;
+};
+} // namespace detail
+
+namespace {
+// Maps to float if T is Eigen::half, and to T otherwise.
+template <typename T>
+using PseudoHalfType = typename detail::PseudoHalfType<T>::Type;
+
+// Returns whether the context's GPU supports efficient fp16 math.
+bool HasFastHalfMath(OpKernelContext* ctx) {
+ int major, minor;
+ ctx->op_device_context()
+ ->stream()
+ ->parent()
+ ->GetDeviceDescription()
+ .cuda_compute_capability(&major, &minor);
+ auto cuda_arch = major * 100 + minor * 10;
+ // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow.
+ return cuda_arch >= 530 && cuda_arch != 610;
+}
+} // namespace
+
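As a standalone check of the trait just introduced (definitions copied from the patch; the <Eigen/Core> include for Eigen::half is an assumption that may vary with the Eigen version), the mapping can be verified at compile time:

// Compile-time verification of the pseudo-half type mapping.
#include <type_traits>
#include <Eigen/Core>

namespace detail {
template <typename T>
struct PseudoHalfType {
  using Type = T;
};
template <>
struct PseudoHalfType<Eigen::half> {
  using Type = float;
};
}  // namespace detail

template <typename T>
using PseudoHalfType = typename detail::PseudoHalfType<T>::Type;

// Eigen::half accumulates in float; every other type accumulates in itself.
static_assert(std::is_same<PseudoHalfType<Eigen::half>, float>::value, "");
static_assert(std::is_same<PseudoHalfType<float>, float>::value, "");
static_assert(std::is_same<PseudoHalfType<double>, double>::value, "");

int main() { return 0; }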
+template <typename T, DepthwiseConv2dDirection kDirection,
+ int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
+ bool kKnownEvenHeight>
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+ const DepthwiseArgs& args, const T* input,
+ const T* filter, T* output,
+ TensorFormat data_format) {
+#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530
+ if (HasFastHalfMath(ctx)) {
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, kBlockDepth,
+ kKnownEvenHeight, T>(
+ ctx, args, input, filter, output, data_format);
+ }
+#endif
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, kBlockDepth,
+ kKnownEvenHeight, PseudoHalfType<T>>(
+ ctx, args, input, filter, output, data_format);
}
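The runtime gate in HasFastHalfMath reduces to the compute-capability predicate below. This host-side sketch uses the plain CUDA runtime API rather than the stream executor and is purely illustrative:

#include <cstdio>
#include <cuda_runtime.h>

// Same predicate as HasFastHalfMath: native fp16 math needs sm_53 or newer,
// and sm_61 is skipped because its fp16 throughput is poor.
bool HasFastHalfMathSketch(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  const int cuda_arch = prop.major * 100 + prop.minor * 10;
  return cuda_arch >= 530 && cuda_arch != 610;
}

int main() {
  std::printf("fast fp16 math on device 0: %s\n",
              HasFastHalfMathSketch(0) ? "yes" : "no");
  return 0;
}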
template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
- const DepthwiseArgs& args, const T* input,
- const T* filter, T* output,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+ const DepthwiseArgs& args, const T* input,
+ const T* filter, T* output,
+ TensorFormat data_format) {
if (args.in_rows & 1) {
- LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
- kKnownFilterHeight, kBlockDepth, false>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, kBlockDepth,
+ false>(ctx, args, input, filter,
+ output, data_format);
} else {
- LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
- kKnownFilterHeight, kBlockDepth, true>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, kBlockDepth, true>(
+ ctx, args, input, filter, output, data_format);
}
}
template <typename T, DepthwiseConv2dDirection kDirection,
int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
- const DepthwiseArgs& args, const T* input,
- const T* filter, T* output,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
+ const DepthwiseArgs& args, const T* input,
+ const T* filter, T* output,
+ TensorFormat data_format) {
// Maximize (power of two) kBlockDepth while keeping a block within 1024
// threads (2 pixels per thread).
const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols;
if (block_pixels > 256) {
- LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
- kKnownFilterHeight, 2>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, 2>(
+ ctx, args, input, filter, output, data_format);
} else if (block_pixels > 128) {
- LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
- kKnownFilterHeight, 4>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, 4>(
+ ctx, args, input, filter, output, data_format);
} else {
- LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
- kKnownFilterHeight, 8>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
+ kKnownFilterHeight, 8>(
+ ctx, args, input, filter, output, data_format);
}
}
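To make the thresholds concrete (illustrative shapes, not taken from the patch): a 16x16 input gives block_pixels = (16 + 1) / 2 * 16 = 128, so the final branch selects kBlockDepth = 8 and the block holds 128 * 8 = 1024 threads; a 24x24 input gives block_pixels = 12 * 24 = 288 > 256, so kBlockDepth = 2 and the block holds 576 threads.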
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
- const DepthwiseArgs& args, const T* input,
- const T* filter, T* output,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format) {
void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
switch (data_format) {
case FORMAT_NHWC:
@@ -691,11 +752,12 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
kKnownDepthMultiplier>;
break;
default:
- LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
- return;
+ return errors::InvalidArgument("FORMAT_", ToString(data_format),
+ " is not supported");
}
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
+ auto device = ctx->eigen_gpu_device();
CudaLaunchConfig config =
GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0);
// The compile-time constant version runs faster with a single block.
@@ -706,26 +768,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
kernel<<<std::min(max_block_count, config.block_count),
config.thread_per_block, 0, device.stream()>>>(args, input, filter,
output, num_outputs);
+ return Status::OK();
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
- const DepthwiseArgs& args, const T* input,
- const T* filter, T* output,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format) {
if (args.depth_multiplier == 1) {
if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
- LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_FORWARD, kKnownFilterWidth,
- kKnownFilterHeight>(
- device, args, input, filter, output, data_format);
- return;
+ return LaunchDepthwiseConv2dGPUSmall<
+ T, DIRECTION_FORWARD, kKnownFilterWidth, kKnownFilterHeight>(
+ ctx, args, input, filter, output, data_format);
}
- LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, 1>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight,
+ 1>(ctx, args, input, filter, output,
+ data_format);
} else {
- LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, -1>(
- device, args, input, filter, output, data_format);
+ return LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight,
+ -1>(ctx, args, input, filter, output,
+ data_format);
}
}
@@ -736,18 +799,13 @@ void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx,
const T* input,
const T* filter, T* output,
TensorFormat data_format) {
- const GpuDevice& device = ctx->eigen_device<GpuDevice>();
if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dGPU<T, 3, 3>(device, args, input, filter, output,
- data_format);
+ OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU<T, 3, 3>(
+ ctx, args, input, filter, output, data_format));
} else {
- LaunchDepthwiseConv2dGPU<T, -1, -1>(device, args, input, filter, output,
- data_format);
+ OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU<T, -1, -1>(
+ ctx, args, input, filter, output, data_format));
}
- auto stream = ctx->op_device_context()->stream();
- OP_REQUIRES(ctx, stream->ok(),
- errors::Internal(
- "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
}
template struct LaunchDepthwiseConvOp<GpuDevice, Eigen::half>;
@@ -904,11 +962,11 @@ __global__ void __launch_bounds__(640, 2)
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
- const DepthwiseArgs& args,
- const T* out_backprop,
- const T* filter, T* in_backprop,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
+ const DepthwiseArgs& args,
+ const T* out_backprop,
+ const T* filter, T* in_backprop,
+ TensorFormat data_format) {
void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
switch (data_format) {
case FORMAT_NHWC:
@@ -920,38 +978,39 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
break;
default:
- LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
- return;
+ return errors::InvalidArgument("FORMAT_", ToString(data_format),
+ " is not supported");
}
const int num_in_backprop =
args.batch * args.in_rows * args.in_cols * args.in_depth;
+ auto device = ctx->eigen_gpu_device();
CudaLaunchConfig config =
GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0);
kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
+ return Status::OK();
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
- const DepthwiseArgs& args,
- const T* out_backprop,
- const T* filter, T* in_backprop,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx,
+ const DepthwiseArgs& args,
+ const T* out_backprop,
+ const T* filter, T* in_backprop,
+ TensorFormat data_format) {
if (args.depth_multiplier == 1) {
if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
- LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_BACKWARD, kKnownFilterWidth,
- kKnownFilterHeight>(
- device, args, out_backprop, filter, in_backprop, data_format);
- return;
+ return LaunchDepthwiseConv2dGPUSmall<
+ T, DIRECTION_BACKWARD, kKnownFilterWidth, kKnownFilterHeight>(
+ ctx, args, out_backprop, filter, in_backprop, data_format);
}
- LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
- kKnownFilterHeight, 1>(
- device, args, out_backprop, filter, in_backprop, data_format);
+ return LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, 1>(
+ ctx, args, out_backprop, filter, in_backprop, data_format);
} else {
- LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
- kKnownFilterHeight, -1>(
- device, args, out_backprop, filter, in_backprop, data_format);
+ return LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, -1>(
+ ctx, args, out_backprop, filter, in_backprop, data_format);
}
}
@@ -960,19 +1019,15 @@ template <typename T>
void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()(
OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
const T* filter, T* in_backprop, TensorFormat data_format) {
- const GpuDevice& device = ctx->eigen_device<GpuDevice>();
if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
- device, args, out_backprop, filter, in_backprop, data_format);
+ OP_REQUIRES_OK(
+ ctx, LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
+ ctx, args, out_backprop, filter, in_backprop, data_format));
} else {
- LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
- device, args, out_backprop, filter, in_backprop, data_format);
+ OP_REQUIRES_OK(
+ ctx, LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
+ ctx, args, out_backprop, filter, in_backprop, data_format));
}
- auto stream = ctx->op_device_context()->stream();
- OP_REQUIRES(ctx, stream->ok(),
- errors::Internal("Launch of gpu kernel for "
- "DepthwiseConv2dBackpropInp"
- "utGPULaunch failed"));
}
template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, Eigen::half>;
@@ -1111,15 +1166,18 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
// up in global memory using atomics.
// Requirements: threads per block must be multiple of 32 and <= launch_bounds,
// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
+// T is the tensors' data type. S is the math type the kernel uses. This is the
+// same as T for all cases but pseudo half (which has T=Eigen::half, S=float).
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- int kBlockDepth, int kAccumPixels>
+ int kBlockDepth, int kAccumPixels, typename S>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
// Holds block plus halo and filter data for blockDim.x depths.
- extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
- T* const shared_data = reinterpret_cast<T*>(shared_memory);
+ extern __shared__ __align__(8) unsigned char shared_memory[];
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+ S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
const int in_height = args.in_rows;
@@ -1169,7 +1227,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Initialize tile, in particular the padding and accumulator.
for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
- shared_data[i] = T(0);
+ shared_data[i] = S();
}
__syncthreads();
@@ -1203,10 +1261,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
if (channel_in_range) {
const T* const in_ptr = inout_offset + input;
- T* const tile_ptr = tile_idx + shared_data;
- tile_ptr[0] = ldg(in_ptr);
+ S* const tile_ptr = tile_idx + shared_data;
+ tile_ptr[0] = static_cast<S>(ldg(in_ptr));
if (!skip_second) {
- tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
+ tile_ptr[tile_offset] = static_cast<S>(ldg(tensor_offset + in_ptr));
}
}
@@ -1216,14 +1274,15 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
if (channel_in_range) {
const T* const out_ptr = inout_offset + output;
- const T out1 = ldg(out_ptr);
- const T out2 = skip_second ? T(0) : ldg(tensor_offset + out_ptr);
+ const S out1 = static_cast<S>(ldg(out_ptr));
+ const S out2 =
+ skip_second ? S() : static_cast<S>(ldg(tensor_offset + out_ptr));
int shared_offset = data_idx;
- T* accum_ptr = accum_offset + shared_data;
+ S* accum_ptr = accum_offset + shared_data;
UNROLL for (int r = 0; r < filter_height; ++r) {
UNROLL for (int c = 0; c < filter_width; ++c) {
- const T* const tile_ptr = shared_offset + shared_data;
- T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+ const S* const tile_ptr = shared_offset + shared_data;
+ S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockDepth; delta /= 2) {
val += CudaShuffleXorSync(active_threads, val, delta);
@@ -1241,18 +1300,18 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- const T* const accum_data = tile_size + shared_data;
+ const S* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / kBlockDepth;
const int filter_channel = filter_idx % kBlockDepth + start_channel;
const int filter_offset = filter_pix * in_depth + filter_channel;
if (filter_channel < in_depth) {
- T val = accum_data[i];
+ S val = accum_data[i];
// Warp-accumulate the pixels of the same depth from the accumulator.
val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
- CudaAtomicAdd(filter_offset + filter, val);
+ CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
}
}
}
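The warp-level accumulation used here, a butterfly over CudaShuffleXorSync followed by a single atomic update per warp, can be sketched with the underlying __shfl_xor_sync primitive; the kernel name and sizes below are illustrative only.

#include <cstdio>
#include <cuda_runtime.h>

// Butterfly sum across a full warp, then one atomic add per warp.
__global__ void WarpSumSketch(const float* in, float* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  float val = i < n ? in[i] : 0.f;
  for (int delta = 16; delta > 0; delta /= 2) {
    val += __shfl_xor_sync(0xffffffffu, val, delta);
  }
  if ((threadIdx.x & 31) == 0) {  // lane 0 holds the warp sum
    atomicAdd(out, val);
  }
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.f;
  *out = 0.f;
  WarpSumSketch<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("sum = %f\n", *out);  // expect 1024
  cudaFree(in);
  cudaFree(out);
  return 0;
}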
@@ -1382,14 +1441,15 @@ __global__ void __launch_bounds__(640, 2)
// Requirements: threads per block must be multiple of 32 and <= launch_bounds,
// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- int kBlockDepth, int kAccumPixels>
+ int kBlockDepth, int kAccumPixels, typename S>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
// Holds block plus halo and filter data for blockDim.z depths.
- extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
- T* const shared_data = reinterpret_cast<T*>(shared_memory);
+ extern __shared__ __align__(8) unsigned char shared_memory[];
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
+ S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
const int in_height = args.in_rows;
@@ -1438,7 +1498,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Initialize tile, in particular the padding and accumulator.
for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
- shared_data[i] = T(0);
+ shared_data[i] = S();
}
__syncthreads();
@@ -1468,10 +1528,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
if (channel_in_range) {
const T* const in_ptr = inout_offset + input;
- T* const tile_ptr = tile_idx + shared_data;
- tile_ptr[0] = ldg(in_ptr);
+ S* const tile_ptr = tile_idx + shared_data;
+ tile_ptr[0] = static_cast<S>(ldg(in_ptr));
if (!skip_second) {
- tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
+ tile_ptr[tile_offset] = static_cast<S>(ldg(block_pixels + in_ptr));
}
}
@@ -1481,14 +1541,15 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
if (channel_in_range) {
const T* const out_ptr = inout_offset + output;
- const T out1 = ldg(out_ptr);
- const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr);
+ const S out1 = static_cast<S>(ldg(out_ptr));
+ const S out2 =
+ skip_second ? S() : static_cast<S>(ldg(block_pixels + out_ptr));
int shared_offset = data_idx;
- T* accum_ptr = accum_offset + shared_data;
+ S* accum_ptr = accum_offset + shared_data;
UNROLL for (int r = 0; r < filter_height; ++r) {
UNROLL for (int c = 0; c < filter_width; ++c) {
- const T* const tile_ptr = shared_offset + shared_data;
- T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
+ const S* const tile_ptr = shared_offset + shared_data;
+ S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) {
val += CudaShuffleXorSync(active_threads, val, delta);
@@ -1506,7 +1567,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- const T* const accum_data = tile_size + shared_data;
+ const S* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / kBlockDepth;
@@ -1514,11 +1575,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
(channel + filter_idx % kBlockDepth) % in_depth;
const int filter_offset = filter_pix * in_depth + filter_channel;
if (filter_channel < in_depth) {
- T val = accum_data[i];
+ S val = accum_data[i];
// Warp-accumulate pixels of the same depth from the accumulator.
val = WarpSumReduce<kAccumPixels>(val);
if (!(thread_idx & kAccumPixels - 1)) {
- CudaAtomicAdd(filter_offset + filter, val);
+ CudaAtomicAdd(filter_offset + filter, static_cast<T>(val));
}
}
}
@@ -1526,19 +1587,20 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
- int kBlockDepth, int kAccumPixels>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
- const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
+ int kBlockDepth, int kAccumPixels, typename S>
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
+ auto device = ctx->eigen_gpu_device();
const int tile_width = args.in_cols + args.filter_cols - 1;
const int tile_height = block_height * 2 + args.filter_rows - 1;
const int tile_pixels = tile_height * tile_width;
const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
- kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T);
+ kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(S);
if (shared_memory_size > device.sharedMemPerBlock()) {
- return false;
+ return errors::FailedPrecondition("Not enough shared memory");
}
dim3 block_dim;
@@ -1550,18 +1612,20 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
block_count =
args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+ S>;
break;
case FORMAT_NCHW:
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
block_count =
DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
- T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+ S>;
break;
default:
- LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
- return false;
+ return errors::InvalidArgument("FORMAT_", ToString(data_format),
+ " is not supported");
}
const int num_out_backprop = args.out_rows * args.out_cols * block_count;
CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
@@ -1569,13 +1633,33 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
block_dim.x * block_dim.y * block_dim.z);
kernel<<<config.block_count, block_dim, shared_memory_size,
device.stream()>>>(args, out_backprop, input, filter_backprop);
- return true;
+ return Status::OK();
+}
+
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+ int kBlockDepth, int kAccumPixels>
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format) {
+#if !defined __CUDA_ARCH__ || __CUDA_ARCH__ >= 530
+ if (HasFastHalfMath(ctx)) {
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels, T>(
+ ctx, args, block_height, out_backprop, input, filter_backprop,
+ data_format);
+ }
+#endif
+ return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
+ T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels,
+ PseudoHalfType<T>>(ctx, args, block_height, out_backprop, input,
+ filter_backprop, data_format);
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kBlockDepth>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
- const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
// Minimize (power of two) kAccumPixels, while satisfying
@@ -1584,24 +1668,24 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
if (block_pixels > 512) {
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
} else if (block_pixels > 256) {
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
} else {
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
}
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
- const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop,
+Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
const T* input, T* filter_backprop, TensorFormat data_format) {
// Maximize (power of two) kBlockDepth while keeping a block within 1024
// threads (2 pixels per thread).
@@ -1621,37 +1705,35 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
}
if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) {
- return false;
+ return errors::FailedPrecondition("Cannot launch this configuration");
}
switch (block_depth) {
case 8:
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, 8>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
case 4:
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, 4>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
case 2:
return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
T, kKnownFilterWidth, kKnownFilterHeight, 2>(
- device, args, block_height, out_backprop, input, filter_backprop,
+ ctx, args, block_height, out_backprop, input, filter_backprop,
data_format);
default:
- return false;
+ return errors::InvalidArgument("Unexpected block depth");
}
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
-void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
- const DepthwiseArgs& args,
- const T* out_backprop,
- const T* input, T* filter_backprop,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropFilterGPU(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+ const T* input, T* filter_backprop, TensorFormat data_format) {
void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
switch (data_format) {
case FORMAT_NHWC:
@@ -1663,37 +1745,38 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
break;
default:
- LOG(ERROR) << "FORMAT_" << ToString(data_format) << " is not supported";
- return;
+ return errors::InvalidArgument("FORMAT_", ToString(data_format),
+ " is not supported");
}
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
+ auto device = ctx->eigen_gpu_device();
CudaLaunchConfig config =
GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0);
kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
+ return Status::OK();
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
-void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
- const DepthwiseArgs& args,
- const T* out_backprop,
- const T* input, T* filter_backprop,
- TensorFormat data_format) {
+Status LaunchDepthwiseConv2dBackpropFilterGPU(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+ const T* input, T* filter_backprop, TensorFormat data_format) {
if (args.depth_multiplier == 1) {
if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
kKnownFilterHeight>(
- device, args, out_backprop, input, filter_backprop, data_format)) {
- return;
+ ctx, args, out_backprop, input, filter_backprop, data_format)
+ .ok()) {
+ return Status::OK();
}
- LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
- kKnownFilterHeight, 1>(
- device, args, out_backprop, input, filter_backprop, data_format);
+ return LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, 1>(
+ ctx, args, out_backprop, input, filter_backprop, data_format);
} else {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
- kKnownFilterHeight, -1>(
- device, args, out_backprop, input, filter_backprop, data_format);
+ return LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
+ kKnownFilterHeight, -1>(
+ ctx, args, out_backprop, input, filter_backprop, data_format);
}
}
@@ -1702,7 +1785,6 @@ template <typename T>
void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
const T* input, T* filter_backprop, TensorFormat data_format) {
- const GpuDevice& device = ctx->eigen_device<GpuDevice>();
auto stream = ctx->op_device_context()->stream();
// Initialize the results to 0.
@@ -1712,16 +1794,14 @@ void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
- device, args, out_backprop, input, filter_backprop, data_format);
+ OP_REQUIRES_OK(
+ ctx, LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
+ ctx, args, out_backprop, input, filter_backprop, data_format));
} else {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
- device, args, out_backprop, input, filter_backprop, data_format);
+ OP_REQUIRES_OK(
+ ctx, LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
+ ctx, args, out_backprop, input, filter_backprop, data_format));
}
- OP_REQUIRES(ctx, stream->ok(),
- errors::Internal("Launch of gpu kernel for "
- "DepthwiseConv2dBackpropFil"
- "terGPULaunch failed"));
}
template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, Eigen::half>;