author     A. Unique TensorFlower <gardener@tensorflow.org>    2017-09-04 03:19:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2017-09-04 03:24:08 -0700
commit     07356b48e4b374efd406fd142faa77cfa4db05e9 (patch)
tree       f5049f7ef36486535e386934f3dfc48f72831f45
parent     0302320e11c7561cafac1cc279fea87de02b0cf9 (diff)
Expose the launchpad for conv2d backprop, and unify the launchpads for conv2d and depthwise_conv to match the example in the documentation (see ./extend/adding_an_op.md)
PiperOrigin-RevId: 167480081
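
For context, the pattern the commit message refers to is the functor-launch idiom from ./extend/adding_an_op.md: declare one templated launcher struct in a header, partially specialize it per device, and keep the OpKernel's Compute body device-agnostic. A minimal sketch of that idiom follows; ExampleFunctor and ExampleOp mirror the documentation's example and are not part of this commit.

// Minimal sketch of the per-device launcher pattern from
// ./extend/adding_an_op.md that this change adopts for the conv2d
// backprop kernels. ExampleFunctor/ExampleOp are illustrative names.
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Primary template: declared once, specialized per device.
template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

// CPU specialization carries the actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) out[i] = 2 * in[i];
  }
};

// The kernel only forwards tensors to the functor, exactly as
// Conv2DFastBackpropFilterOp below forwards to
// LaunchConv2DBackpropFilterOp<Device, T>.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    ExampleFunctor<Device, T>()(context->eigen_device<Device>(),
                                static_cast<int>(input.NumElements()),
                                input.flat<T>().data(),
                                output->flat<T>().data());
  }
};

}  // namespace tensorflow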
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc      | 690
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc       | 716
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.h              |  37
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                  |  42
-rw-r--r--  tensorflow/core/kernels/conv_ops.h                   |  32
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_grad_op.cc    |  74
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.cc         |  42
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.h          |  47
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc  | 103
9 files changed, 922 insertions(+), 861 deletions(-)
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 65514937f4..8eb705b2e5 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -91,6 +91,20 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+template <typename T>
+struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& input,
+ int row_stride, int col_stride, const Padding& padding,
+ Tensor* filter_backprop, TensorFormat data_format) {
+ const CPUDevice& d = ctx->eigen_device<CPUDevice>();
+ functor::SpatialConvolutionBackwardKernel<CPUDevice, T>()(
+ d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
+ out_backprop.tensor<T, 4>(), filter_backprop->dim_size(0),
+ filter_backprop->dim_size(1), row_stride, col_stride);
+ }
+};
+
#ifdef TENSORFLOW_USE_LIBXSMM
template <typename Device, class T>
struct LaunchXsmmBackwardFilter {
@@ -237,11 +251,9 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
}
#endif
- functor::SpatialConvolutionBackwardKernel<Device, T>()(
- context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(),
- input.tensor<T, 4>(), out_backprop.tensor<T, 4>(),
- dims.spatial_dims[0].filter_size, dims.spatial_dims[1].filter_size,
- dims.spatial_dims[0].stride, dims.spatial_dims[1].stride);
+ LaunchConv2DBackpropFilterOp<Device, T>()(
+ context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
+ dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
}
private:
@@ -495,15 +507,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmType;
- using perftools::gputools::dnn::ProfileResult;
- using perftools::gputools::dnn::kDefaultAlgorithm;
const Tensor& input = context->input(0);
const Tensor& filter_sizes = context->input(1);
const Tensor& out_backprop = context->input(2);
@@ -512,352 +519,373 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
errors::InvalidArgument(
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_sizes.dims()));
- const TensorShape& input_shape = input.shape();
TensorShape filter_shape;
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
filter_sizes.vec<int32>(), &filter_shape));
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(context,
- ConvBackpropComputeDimensions(
- "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
- input.shape(), filter_shape, out_backprop.shape(),
- strides_, padding_, data_format_, &dims));
-
Tensor* filter_backprop = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
- const int padding_rows =
- (padding_ == VALID)
- ? 0
- : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
- dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
- const int padding_cols =
- (padding_ == VALID)
- ? 0
- : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
- dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
-
- // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
- // calling it when that is true. Remove this check when (if?) cuDNN starts
- // supporting different padding.
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
-
- auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-
- if (!use_cudnn_) {
- context->SetStatus(errors::Unimplemented(
- "Conv2DBackprop for GPU is not currently supported "
- "without cudnn"));
- return;
- }
+ // For now we take the stride from the second and third dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- if (!cudnn_disable_conv_1x1_optimization_ &&
- dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
- dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = dims.in_depth;
- const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
- dims.spatial_dims[1].input_size;
- const uint64 n = dims.out_depth;
-
- // The shape of output backprop is
- // [batch, out_rows, out_cols, out_depth]
- // From cublas's perspective, it is: n x k
- auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
- out_backprop.template flat<T>().size());
-
- // The shape of input is
- // [batch, in_rows, in_cols, in_depth],
- // From cublas's perspective, it is: m x k
- auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
-
- // the shape of the filter backprop from the conv_2d should be
- // [1, 1, in_depth, out_depth]
- // From cublas's perspective, it is: n x m
- auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
- filter_backprop->template flat<T>().size());
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose,
- n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
- ", n=", n, ", k=", k));
- }
- return;
- } else if (dims.spatial_dims[0].filter_size ==
- dims.spatial_dims[0].input_size &&
- dims.spatial_dims[1].filter_size ==
- dims.spatial_dims[1].input_size &&
- padding_ == VALID && data_format_ == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
- const uint64 m = dims.spatial_dims[0].input_size *
- dims.spatial_dims[1].input_size * dims.in_depth;
- const uint64 k = dims.batch_size;
- const uint64 n = dims.out_depth;
-
- auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
- auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
- out_backprop.template flat<T>().size());
- auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
- filter_backprop->template flat<T>().size());
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose,
- n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
- ", n=", n, ", k=", k));
- }
- return;
- }
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
+ stride_rows, stride_cols, padding_, filter_backprop,
+ data_format_);
+ }
- Tensor compatible_input;
- if (rows_odd || cols_odd) {
- // If a padding dimension is odd, we have one more element on the right
- // side or the bottom side. This is unsupported in cudnn. Therefore,
- // we pad that extra element and make it compatible.
- OP_REQUIRES_OK(
- context,
- context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, dims.batch_size,
- dims.spatial_dims[0].input_size + rows_odd,
- dims.spatial_dims[1].input_size + cols_odd,
- dims.in_depth),
- &compatible_input));
-
- functor::PadInput<GPUDevice, T, int, 4>()(
- context->template eigen_device<GPUDevice>(),
- To32Bit(input.tensor<T, 4>()), {{0, 0}}, {{rows_odd, cols_odd}},
- To32Bit(compatible_input.tensor<T, 4>()), data_format_);
- } else {
- compatible_input = input;
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ bool use_cudnn_;
+ TensorFormat data_format_;
+ LaunchConv2DBackpropFilterOp<Device, T> launcher_;
+ bool cudnn_use_autotune_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
+};
+
+template <typename T>
+void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
+ OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& input, int row_stride,
+ int col_stride, const Padding& padding, Tensor* filter_backprop,
+ TensorFormat data_format) {
+ using perftools::gputools::dnn::AlgorithmConfig;
+ using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::ProfileResult;
+
+ std::vector<int32> strides(4, 1);
+ strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
+ strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
+ TensorShape filter_shape = filter_backprop->shape();
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
+ input.shape(), filter_shape, out_backprop.shape(),
+ strides, padding, data_format, &dims));
+
+ const int padding_rows =
+ (padding == VALID)
+ ? 0
+ : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
+ dims.spatial_dims[0].stride +
+ dims.spatial_dims[0].filter_size -
+ dims.spatial_dims[0].input_size);
+ const int padding_cols =
+ (padding == VALID)
+ ? 0
+ : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
+ dims.spatial_dims[1].stride +
+ dims.spatial_dims[1].filter_size -
+ dims.spatial_dims[1].input_size);
+
+ // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
+ // calling it when that is true. Remove this check when (if?) cuDNN starts
+ // supporting different padding.
+ bool rows_odd = (padding_rows % 2 != 0);
+ bool cols_odd = (padding_cols % 2 != 0);
+
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ if (!use_cudnn) {
+ ctx->SetStatus(errors::Unimplemented(
+ "Conv2DBackprop for GPU is not currently supported "
+ "without cudnn"));
+ return;
+ }
+
+ bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
+ if (!cudnn_disable_conv_1x1_optimization_ &&
+ dims.spatial_dims[0].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
+ data_format == FORMAT_NHWC) {
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size;
+ const uint64 n = dims.out_depth;
+
+ // The shape of output backprop is
+ // [batch, out_rows, out_cols, out_depth]
+ // From cublas's perspective, it is: n x k
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+
+ // The shape of input is
+ // [batch, in_rows, in_cols, in_depth],
+ // From cublas's perspective, it is: m x k
+ auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+
+ // the shape of the filter backprop from the conv_2d should be
+ // [1, 1, in_depth, out_depth]
+ // From cublas's perspective, it is: n x m
+ auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
+ filter_backprop->template flat<T>().size());
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose, n,
+ m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
}
+ return;
+ } else if (dims.spatial_dims[0].filter_size ==
+ dims.spatial_dims[0].input_size &&
+ dims.spatial_dims[1].filter_size ==
+ dims.spatial_dims[1].input_size &&
+ padding == VALID && data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, so call cublas
+ // directly.
+ const uint64 m = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
+
+ auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
+ filter_backprop->template flat<T>().size());
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose, n,
+ m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
- CHECK(padding_rows >= 0 && padding_cols >= 0)
- << "Negative row or col paddings: (" << padding_rows << ", "
- << padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
- input_desc.set_count(dims.batch_size)
- .set_height(GetTensorDim(compatible_input, data_format_, 'H'))
- .set_width(GetTensorDim(compatible_input, data_format_, 'W'))
- .set_feature_map_count(dims.in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
- output_desc.set_count(dims.batch_size)
- .set_height(dims.spatial_dims[0].output_size)
- .set_width(dims.spatial_dims[1].output_size)
- .set_feature_map_count(dims.out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
- .set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
- .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
- .set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
-
- // NOTE(zhengxq):
- // cuDNN only supports the following layouts :
- // Input : B x D x R x C
- // Filter : OD x ID x R x C
- // Whereas, we have
- // Input : B x R x C x D
- // Filter : R x C x ID x OD
- // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
- // The first TransformDepth performs
- // (B x R x C x D) => (B x D x R x C).
- // Since the tensor returned from cuDNN is B x D x R x C also,
- // the second TransformDepth performs
- // (B x D x R x C) => (B x R x C x D).
-
- Tensor pre_transformed_filter_backprop;
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({dims.out_depth, dims.in_depth,
- dims.spatial_dims[0].filter_size,
- dims.spatial_dims[1].filter_size}),
- &pre_transformed_filter_backprop));
-
- Tensor transformed_out_backprop;
- if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = ShapeFromFormat(
- FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
- dims.spatial_dims[1].output_size, dims.out_depth);
- if (dims.out_depth > 1) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value, nchw_shape,
- &transformed_out_backprop));
- functor::NHWCToNCHW<Device, T, 4>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
- transformed_out_backprop.tensor<T, 4>());
- } else {
- // If depth <= 1, just reshape.
- CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
- }
+ Tensor compatible_input;
+ if (rows_odd || cols_odd) {
+ // If a padding dimension is odd, we have one more element on the right
+ // side or the bottom side. This is unsupported in cudnn. Therefore,
+ // we pad that extra element and make it compatible.
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, dims.batch_size,
+ dims.spatial_dims[0].input_size + rows_odd,
+ dims.spatial_dims[1].input_size + cols_odd,
+ dims.in_depth),
+ &compatible_input));
+
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
+ {{0, 0}}, {{rows_odd, cols_odd}},
+ To32Bit(compatible_input.tensor<T, 4>()), data_format);
+ } else {
+ compatible_input = input;
+ }
+
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(dims.batch_size)
+ .set_height(GetTensorDim(compatible_input, data_format, 'H'))
+ .set_width(GetTensorDim(compatible_input, data_format, 'W'))
+ .set_feature_map_count(dims.in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(dims.batch_size)
+ .set_height(dims.spatial_dims[0].output_size)
+ .set_width(dims.spatial_dims[1].output_size)
+ .set_feature_map_count(dims.out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
+ .set_input_filter_width(dims.spatial_dims[1].filter_size)
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ // NOTE(zhengxq):
+ // cuDNN only supports the following layouts :
+ // Input : B x D x R x C
+ // Filter : OD x ID x R x C
+ // Whereas, we have
+ // Input : B x R x C x D
+ // Filter : R x C x ID x OD
+ // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
+ // The first TransformDepth performs
+ // (B x R x C x D) => (B x D x R x C).
+ // Since the tensor returned from cuDNN is B x D x R x C also,
+ // the second TransformDepth performs
+ // (B x D x R x C) => (B x R x C x D).
+
+ Tensor pre_transformed_filter_backprop;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[1].filter_size}),
+ &pre_transformed_filter_backprop));
+
+ Tensor transformed_out_backprop;
+ if (data_format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(
+ FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
+ dims.spatial_dims[1].output_size, dims.out_depth);
+ if (dims.out_depth > 1) {
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
+ &transformed_out_backprop));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
+ transformed_out_backprop.tensor<T, 4>());
} else {
- transformed_out_backprop = out_backprop;
+ // If depth <= 1, just reshape.
+ CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
}
+ } else {
+ transformed_out_backprop = out_backprop;
+ }
- Tensor transformed_input;
- if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = ShapeFromFormat(
- FORMAT_NCHW, GetTensorDim(compatible_input, data_format_, 'N'),
- GetTensorDim(compatible_input, data_format_, 'H'),
- GetTensorDim(compatible_input, data_format_, 'W'),
- GetTensorDim(compatible_input, data_format_, 'C'));
- if (nchw_shape.dim_size(1) > 1) {
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- nchw_shape, &transformed_input));
- functor::NHWCToNCHW<Device, T, 4>()(
- context->eigen_device<Device>(),
- const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
- } else {
- // If depth <= 1, just reshape.
- CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
- }
+ Tensor transformed_input;
+ if (data_format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(
+ FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'),
+ GetTensorDim(compatible_input, data_format, 'H'),
+ GetTensorDim(compatible_input, data_format, 'W'),
+ GetTensorDim(compatible_input, data_format, 'C'));
+ if (nchw_shape.dim_size(1) > 1) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ nchw_shape, &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
} else {
- transformed_input = compatible_input;
+ // If depth <= 1, just reshape.
+ CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
}
+ } else {
+ transformed_input = compatible_input;
+ }
- auto out_backprop_ptr =
- AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
- transformed_out_backprop.template flat<T>().size());
- auto filter_backprop_ptr = AsDeviceMemory(
- pre_transformed_filter_backprop.template flat<T>().data(),
- pre_transformed_filter_backprop.template flat<T>().size());
- auto input_ptr =
- AsDeviceMemory(transformed_input.template flat<T>().data(),
- transformed_input.template flat<T>().size());
-
- static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
- );
- int device_id = stream->parent()->device_ordinal();
- DataType dtype = input.dtype();
- ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor datatype
- device_id, // device_id
- };
- AlgorithmConfig algorithm_config;
- if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find(
- conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
- ProfileResult best_result;
- ProfileResult best_result_no_scratch;
- for (auto profile_algorithm : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardFilterScratchSize, context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr,
- conv_desc, filter_desc, &filter_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_backprop_ptr =
+ AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
+ pre_transformed_filter_backprop.template flat<T>().size());
+ auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+
+ static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
+ );
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = input.dtype();
+ ConvParameters conv_parameters = {
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor datatype
+ device_id, // device_id
+ };
+ AlgorithmConfig algorithm_config;
+ if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
+ conv_parameters, &algorithm_config)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
+ ProfileResult best_result;
+ ProfileResult best_result_no_scratch;
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
- OP_REQUIRES(context,
- best_result.is_valid() || best_result_no_scratch.is_valid(),
- errors::NotFound("No algorithm worked!"));
- if (best_result.is_valid()) {
- algorithm_config.set_algorithm(best_result.algorithm());
- }
- if (best_result_no_scratch.is_valid()) {
- algorithm_config.set_algorithm_no_scratch(
- best_result_no_scratch.algorithm());
- }
- AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
- algorithm_config);
}
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
- context);
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
- filter_desc, &filter_backprop_ptr, &scratch_allocator,
- algorithm_config, nullptr)
- .ok();
-
- if (!cudnn_launch_status) {
- context->SetStatus(errors::Internal(
- "cuDNN Backward Filter function launch failure : input shape(",
- input_shape.DebugString(), ") filter shape(",
- filter_shape.DebugString(), ")"));
- return;
+ OP_REQUIRES(ctx,
+ best_result.is_valid() || best_result_no_scratch.is_valid(),
+ errors::NotFound("No algorithm worked!"));
+ if (best_result.is_valid()) {
+ algorithm_config.set_algorithm(best_result.algorithm());
}
-
- auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::ReverseTransformFilter<Device, T, 4>()(
- context->eigen_device<Device>(),
- toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
- filter_backprop->tensor<T, 4>());
+ if (best_result_no_scratch.is_valid()) {
+ algorithm_config.set_algorithm_no_scratch(
+ best_result_no_scratch.algorithm());
+ }
+ AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
+ algorithm_config);
+ }
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+ ctx);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
+ filter_desc, &filter_backprop_ptr, &scratch_allocator,
+ algorithm_config, nullptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN Backward Filter function launch failure : input shape(",
+ input.shape().DebugString(), ") filter shape(",
+ filter_shape.DebugString(), ")"));
+ return;
}
- private:
- std::vector<int32> strides_;
- Padding padding_;
- bool use_cudnn_;
- TensorFormat data_format_;
- bool cudnn_use_autotune_;
- bool cudnn_disable_conv_1x1_optimization_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
-};
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::ReverseTransformFilter<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
+ filter_backprop->tensor<T, 4>());
+}
// Forward declarations of the functor specializations for GPU.
namespace functor {
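
Both launchers above compute explicit padding with the same formula: total padding = max(0, (output_size - 1) * stride + filter_size - input_size), split evenly across the two sides; when the total is odd, the input is first padded by one extra row or column, because cuDNN only accepts symmetric padding (the rows_odd/cols_odd handling). A standalone sketch of that arithmetic, with illustrative values:

// Standalone illustration (not TF code) of the explicit padding arithmetic
// used by LaunchConv2DBackpropFilterOp above.
#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative values: 6-wide input, 3-wide filter, stride 2, SAME padding.
  const int in = 6, filter = 3, stride = 2;
  const int out = (in + stride - 1) / stride;             // ceil(6 / 2) == 3
  const int total = std::max(0, (out - 1) * stride + filter - in);  // == 1
  const bool odd = (total % 2 != 0);  // true: pad one extra element first,
                                      // then hand cuDNN total / 2 per side.
  std::printf("out=%d total=%d per_side=%d odd=%d\n", out, total, total / 2,
              odd ? 1 : 0);
  return 0;
}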
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index a5a9549a2f..ce561aa99c 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -97,29 +97,17 @@ typedef Eigen::GpuDevice GPUDevice;
// for CPU for now since nvcc times out when trying to compile them.
// TODO(yangke): enable them for GPUs when we have a faster compiler.
-template <typename Device, class T>
-struct LaunchBackwardInputConvolution {
- bool operator()(OpKernelContext* context, const Device&,
- typename TTypes<T, 4>::Tensor,
- typename TTypes<T, 4>::ConstTensor,
- typename TTypes<T, 4>::ConstTensor, int, int, int, int,
- TensorFormat) const {
- return false;
- }
-};
-
-template <>
-struct LaunchBackwardInputConvolution<CPUDevice, float> {
- bool operator()(OpKernelContext* context, const CPUDevice& d,
- typename TTypes<float, 4>::Tensor input_backward,
- typename TTypes<float, 4>::ConstTensor kernel,
- typename TTypes<float, 4>::ConstTensor output_backward,
- int input_rows, int input_cols, int row_stride,
- int col_stride, TensorFormat data_format) const {
- functor::SpatialConvolutionBackwardInput<CPUDevice, float>()(
- d, input_backward, kernel, output_backward, input_rows, input_cols,
- row_stride, col_stride);
- return true;
+template <typename T>
+struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& filter,
+ int row_stride, int col_stride, const Padding& padding,
+ Tensor* in_backprop, TensorFormat data_format) {
+ const CPUDevice& d = ctx->eigen_device<CPUDevice>();
+ functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
+ d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
+ out_backprop.tensor<T, 4>(), in_backprop->dim_size(1),
+ in_backprop->dim_size(2), row_stride, col_stride);
}
};
@@ -268,11 +256,10 @@ class Conv2DFastBackpropInputOp : public OpKernel {
}
#endif
- LaunchBackwardInputConvolution<Device, T>()(
- context, context->eigen_device<Device>(), in_backprop->tensor<T, 4>(),
- filter.tensor<T, 4>(), out_backprop.tensor<T, 4>(),
- dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
- dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, data_format_);
+ LaunchConv2DBackpropInputOp<Device, T>()(
+ context, false, false, out_backprop, filter,
+ dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
+ in_backprop, data_format_);
}
private:
@@ -600,10 +587,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmType;
- using perftools::gputools::dnn::ProfileResult;
- using perftools::gputools::dnn::kDefaultAlgorithm;
const Tensor& input_sizes = context->input(0);
const Tensor& filter = context->input(1);
const Tensor& out_backprop = context->input(2);
@@ -615,351 +598,372 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
TensorShape input_shape;
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
input_sizes.vec<int32>(), &input_shape));
- const TensorShape& filter_shape = filter.shape();
-
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(
- context, ConvBackpropComputeDimensions(
- "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
- input_shape, filter_shape, out_backprop.shape(), strides_,
- padding_, data_format_, &dims));
Tensor* in_backprop = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- const int padding_rows =
- (padding_ == VALID)
- ? 0
- : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
- dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
- const int padding_cols =
- (padding_ == VALID)
- ? 0
- : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
- dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
-
- // TODO(keveman): cuDNN only supports equal padding on both sides, so only
- // calling it when that is true. Remove this check when (if?) cuDNN starts
- // supporting different padding.
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
-
- auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-
- if (!use_cudnn_) {
- context->SetStatus(errors::Unimplemented(
- "Conv2DBackpropInput for GPU is not currently supported "
- "without cudnn"));
- return;
- }
+ // For now we take the stride from the second and third dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- if (dims.spatial_dims[0].filter_size == 1 &&
- dims.spatial_dims[1].filter_size == 1 &&
- dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
- data_format_ == FORMAT_NHWC) {
- // 1x1 filter, so call cublas directly.
- const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
- dims.spatial_dims[1].input_size;
- const uint64 k = dims.out_depth;
- const uint64 n = dims.in_depth;
-
- auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
- out_backprop.template flat<T>().size());
- auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
- filter.template flat<T>().size());
- auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
- in_backprop->template flat<T>().size());
-
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
- a_ptr, k, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
- ", n=", n, ", k=", k));
- }
- return;
- } else if (dims.spatial_dims[0].filter_size ==
- dims.spatial_dims[0].input_size &&
- dims.spatial_dims[1].filter_size ==
- dims.spatial_dims[1].input_size &&
- padding_ == VALID && data_format_ == FORMAT_NHWC) {
- // The input data and filter have the same height/width, so call cublas
- // directly.
- const uint64 m = dims.batch_size;
- const uint64 k = dims.out_depth;
- const uint64 n = dims.spatial_dims[0].input_size *
- dims.spatial_dims[1].input_size * dims.in_depth;
-
- auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
- out_backprop.template flat<T>().size());
- auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
- filter.template flat<T>().size());
- auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
- in_backprop->template flat<T>().size());
-
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
-
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
- a_ptr, k, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
- ", n=", n, ", k=", k));
- }
- return;
- }
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
+ stride_rows, stride_cols, padding_, in_backprop, data_format_);
+ }
- TensorShape compatible_input_shape;
- if (rows_odd || cols_odd) {
- // If a padding dimension is odd, we have one more element on the right
- // side or the bottom side. This is unsupported in cudnn. Therefore,
- // we pad that extra element and make it compatible.
- compatible_input_shape = ShapeFromFormat(
- data_format_, dims.batch_size,
- dims.spatial_dims[0].input_size + rows_odd,
- dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
- } else {
- compatible_input_shape = input_shape;
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ bool use_cudnn_;
+ TensorFormat data_format_;
+ LaunchConv2DBackpropInputOp<Device, T> launcher_;
+ bool cudnn_use_autotune_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
+};
+
+template <typename T>
+void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
+ OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& filter, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format) {
+ using perftools::gputools::dnn::AlgorithmConfig;
+ using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::ProfileResult;
+
+ std::vector<int32> strides(4, 1);
+ strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
+ strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
+ TensorShape input_shape = in_backprop->shape();
+
+ const TensorShape& filter_shape = filter.shape();
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
+ input_shape, filter_shape, out_backprop.shape(),
+ strides, padding, data_format, &dims));
+
+ const int padding_rows =
+ (padding == VALID)
+ ? 0
+ : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
+ dims.spatial_dims[0].stride +
+ dims.spatial_dims[0].filter_size -
+ dims.spatial_dims[0].input_size);
+ const int padding_cols =
+ (padding == VALID)
+ ? 0
+ : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
+ dims.spatial_dims[1].stride +
+ dims.spatial_dims[1].filter_size -
+ dims.spatial_dims[1].input_size);
+
+ // TODO(keveman): cuDNN only supports equal padding on both sides, so only
+ // calling it when that is true. Remove this check when (if?) cuDNN starts
+ // supporting different padding.
+ bool rows_odd = (padding_rows % 2 != 0);
+ bool cols_odd = (padding_cols % 2 != 0);
+
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ if (!use_cudnn) {
+ ctx->SetStatus(errors::Unimplemented(
+ "Conv2DBackpropInput for GPU is not currently supported "
+ "without cudnn"));
+ return;
+ }
+
+ if (dims.spatial_dims[0].filter_size == 1 &&
+ dims.spatial_dims[1].filter_size == 1 &&
+ dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
+ data_format == FORMAT_NHWC) {
+ // 1x1 filter, so call cublas directly.
+ const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
+
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
+ in_backprop->template flat<T>().size());
+
+ auto transpose = perftools::gputools::blas::Transpose::kTranspose;
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
+ a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
}
+ return;
+ } else if (dims.spatial_dims[0].filter_size ==
+ dims.spatial_dims[0].input_size &&
+ dims.spatial_dims[1].filter_size ==
+ dims.spatial_dims[1].input_size &&
+ padding == VALID && data_format == FORMAT_NHWC) {
+ // The input data and filter have the same height/width, so call cublas
+ // directly.
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size * dims.in_depth;
+
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
+ in_backprop->template flat<T>().size());
+
+ auto transpose = perftools::gputools::blas::Transpose::kTranspose;
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
+ a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
- CHECK(padding_rows >= 0 && padding_cols >= 0)
- << "Negative row or col paddings: (" << padding_rows << ", "
- << padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
- input_desc.set_count(dims.batch_size)
- .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H'))
- .set_width(GetTensorDim(compatible_input_shape, data_format_, 'W'))
- .set_feature_map_count(dims.in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
- output_desc.set_count(dims.batch_size)
- .set_height(dims.spatial_dims[0].output_size)
- .set_width(dims.spatial_dims[1].output_size)
- .set_feature_map_count(dims.out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
- .set_input_filter_width(dims.spatial_dims[1].filter_size)
- .set_input_feature_map_count(dims.in_depth)
- .set_output_feature_map_count(dims.out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
- .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
- .set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
-
- // NOTE(keveman):
- // cuDNN only supports the following layouts :
- // Input : B x D x R x C
- // Filter : OD x ID x R x C
- // Whereas, we have
- // Input : B x R x C x D
- // Filter : R x C x ID x OD
- // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
- // The first TransformDepth performs
- // (B x R x C x D) => (B x D x R x C).
- // Since the tensor returned from cuDNN is B x D x R x C also,
- // the second TransformDepth performs
- // (B x D x R x C) => (B x R x C x D).
- Tensor transformed_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({dims.out_depth, dims.in_depth,
- dims.spatial_dims[0].filter_size,
- dims.spatial_dims[1].filter_size}),
- &transformed_filter));
-
- functor::TransformFilter<Device, T, int, 4>()(
- context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
- To32Bit(transformed_filter.tensor<T, 4>()));
-
- Tensor transformed_out_backprop;
- if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = ShapeFromFormat(
- FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
- dims.spatial_dims[1].output_size, dims.out_depth);
- if (dims.out_depth > 1) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value, nchw_shape,
- &transformed_out_backprop));
- functor::NHWCToNCHW<Device, T, 4>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
- transformed_out_backprop.tensor<T, 4>());
- } else {
- // If depth <= 1, then just reshape.
- CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
- }
+ TensorShape compatible_input_shape;
+ if (rows_odd || cols_odd) {
+ // If a padding dimension is odd, we have one more element on the right
+ // side or the bottom side. This is unsupported in cudnn. Therefore,
+ // we pad that extra element and make it compatible.
+ compatible_input_shape = ShapeFromFormat(
+ data_format, dims.batch_size,
+ dims.spatial_dims[0].input_size + rows_odd,
+ dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
+ } else {
+ compatible_input_shape = input_shape;
+ }
+
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(dims.batch_size)
+ .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
+ .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
+ .set_feature_map_count(dims.in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(dims.batch_size)
+ .set_height(dims.spatial_dims[0].output_size)
+ .set_width(dims.spatial_dims[1].output_size)
+ .set_feature_map_count(dims.out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
+ .set_input_filter_width(dims.spatial_dims[1].filter_size)
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ // NOTE(keveman):
+ // cuDNN only supports the following layouts :
+ // Input : B x D x R x C
+ // Filter : OD x ID x R x C
+ // Whereas, we have
+ // Input : B x R x C x D
+ // Filter : R x C x ID x OD
+ // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
+ // The first TransformDepth performs
+ // (B x R x C x D) => (B x D x R x C).
+ // Since the tensor returned from cuDNN is B x D x R x C also,
+ // the second TransformDepth performs
+ // (B x D x R x C) => (B x R x C x D).
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[1].filter_size}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ To32Bit(transformed_filter.tensor<T, 4>()));
+
+ Tensor transformed_out_backprop;
+ if (data_format == FORMAT_NHWC) {
+ TensorShape nchw_shape = ShapeFromFormat(
+ FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
+ dims.spatial_dims[1].output_size, dims.out_depth);
+ if (dims.out_depth > 1) {
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
+ &transformed_out_backprop));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
+ transformed_out_backprop.tensor<T, 4>());
} else {
- transformed_out_backprop = out_backprop;
+ // If depth <= 1, then just reshape.
+ CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
}
+ } else {
+ transformed_out_backprop = out_backprop;
+ }
- Tensor pre_transformed_in_backprop;
- OP_REQUIRES_OK(
- context,
- context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(
- FORMAT_NCHW,
- GetTensorDim(compatible_input_shape, data_format_, 'N'),
- GetTensorDim(compatible_input_shape, data_format_, 'H'),
- GetTensorDim(compatible_input_shape, data_format_, 'W'),
- GetTensorDim(compatible_input_shape, data_format_, 'C')),
- &pre_transformed_in_backprop));
-
- auto out_backprop_ptr =
- AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
- transformed_out_backprop.template flat<T>().size());
- auto filter_ptr =
- AsDeviceMemory(transformed_filter.template flat<T>().data(),
- transformed_filter.template flat<T>().size());
- auto in_backprop_ptr =
- AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
- pre_transformed_in_backprop.template flat<T>().size());
-
- static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
- );
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
- context);
- int device_id = stream->parent()->device_ordinal();
- DataType dtype = out_backprop.dtype();
- ConvParameters conv_parameters = {
- dims.batch_size, // batch
- dims.in_depth, // in_depths
- {{input_desc.height(), // in_rows
- input_desc.width()}}, // in_cols
- dims.out_depth, // out_depths
- {{dims.spatial_dims[0].filter_size, // filter_rows
- dims.spatial_dims[1].filter_size}}, // filter_cols
- {{dims.spatial_dims[0].stride, // stride_rows
- dims.spatial_dims[1].stride}}, // stride_cols
- {{padding_rows, // padding_rows
- padding_cols}}, // padding_cols
- dtype, // tensor data type
- device_id, // device_id
- };
- AlgorithmConfig algorithm_config;
- if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find(
- conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
- ProfileResult best_result;
- ProfileResult best_result_no_scratch;
- for (auto profile_algorithm : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
- context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ Tensor pre_transformed_in_backprop;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(
+ FORMAT_NCHW,
+ GetTensorDim(compatible_input_shape, data_format, 'N'),
+ GetTensorDim(compatible_input_shape, data_format, 'H'),
+ GetTensorDim(compatible_input_shape, data_format, 'W'),
+ GetTensorDim(compatible_input_shape, data_format, 'C')),
+ &pre_transformed_in_backprop));
+
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto in_backprop_ptr =
+ AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
+ pre_transformed_in_backprop.template flat<T>().size());
+
+ static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
+ );
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = out_backprop.dtype();
+ ConvParameters conv_parameters = {
+ dims.batch_size, // batch
+ dims.in_depth, // in_depths
+ {{input_desc.height(), // in_rows
+ input_desc.width()}}, // in_cols
+ dims.out_depth, // out_depths
+ {{dims.spatial_dims[0].filter_size, // filter_rows
+ dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].stride, // stride_rows
+ dims.spatial_dims[1].stride}}, // stride_cols
+ {{padding_rows, // padding_rows
+ padding_cols}}, // padding_cols
+ dtype, // tensor data type
+ device_id, // device_id
+ };
+ AlgorithmConfig algorithm_config;
+ if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
+ conv_parameters, &algorithm_config)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
+ ProfileResult best_result;
+ ProfileResult best_result_no_scratch;
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
- OP_REQUIRES(context,
- best_result.is_valid() || best_result_no_scratch.is_valid(),
- errors::NotFound("No algorithm worked!"));
- if (best_result.is_valid()) {
- algorithm_config.set_algorithm(best_result.algorithm());
- }
- if (best_result_no_scratch.is_valid()) {
- algorithm_config.set_algorithm_no_scratch(
- best_result_no_scratch.algorithm());
- }
- AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
- algorithm_config);
- }
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
- algorithm_config, nullptr)
- .ok();
-
- if (!cudnn_launch_status) {
- context->SetStatus(errors::Internal(
- "cuDNN Backward Data function launch failure : input shape(",
- input_shape.DebugString(), ") filter shape(",
- filter_shape.DebugString(), ")"));
- return;
}
-
- if (rows_odd || cols_odd) {
- Tensor in_backprop_remove_padding;
- OP_REQUIRES_OK(
- context,
- context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(FORMAT_NCHW,
- GetTensorDim(input_shape, data_format_, 'N'),
- GetTensorDim(input_shape, data_format_, 'H'),
- GetTensorDim(input_shape, data_format_, 'W'),
- GetTensorDim(input_shape, data_format_, 'C')),
- &in_backprop_remove_padding));
-
- // Remove the padding for odd rows or cols.
- functor::PadInput<GPUDevice, T, int, 4>()(
- context->template eigen_device<GPUDevice>(),
- To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
- .tensor<T, 4>()),
- {{0, 0}}, {{-rows_odd, -cols_odd}},
- To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);
-
- pre_transformed_in_backprop = in_backprop_remove_padding;
+ OP_REQUIRES(ctx,
+ best_result.is_valid() || best_result_no_scratch.is_valid(),
+ errors::NotFound("No algorithm worked!"));
+ if (best_result.is_valid()) {
+ algorithm_config.set_algorithm(best_result.algorithm());
}
-
- if (data_format_ == FORMAT_NHWC) {
- auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
- functor::NCHWToNHWC<Device, T, 4>()(
- context->eigen_device<Device>(),
- toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
- in_backprop->tensor<T, 4>());
- } else {
- *in_backprop = pre_transformed_in_backprop;
+ if (best_result_no_scratch.is_valid()) {
+ algorithm_config.set_algorithm_no_scratch(
+ best_result_no_scratch.algorithm());
}
+ AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
+ algorithm_config);
+ }
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
+ input_desc, &in_backprop_ptr, &scratch_allocator,
+ algorithm_config, nullptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN Backward Data function launch failure : input shape(",
+ input_shape.DebugString(), ") filter shape(",
+ filter_shape.DebugString(), ")"));
+ return;
}
- private:
- std::vector<int32> strides_;
- Padding padding_;
- bool use_cudnn_;
- TensorFormat data_format_;
- bool cudnn_use_autotune_;
+ if (rows_odd || cols_odd) {
+ Tensor in_backprop_remove_padding;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW,
+ GetTensorDim(input_shape, data_format, 'N'),
+ GetTensorDim(input_shape, data_format, 'H'),
+ GetTensorDim(input_shape, data_format, 'W'),
+ GetTensorDim(input_shape, data_format, 'C')),
+ &in_backprop_remove_padding));
+
+ // Remove the padding for odd rows or cols.
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->template eigen_device<GPUDevice>(),
+ To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
+ .tensor<T, 4>()),
+ {{0, 0}}, {{-rows_odd, -cols_odd}},
+ To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);
+
+ pre_transformed_in_backprop = in_backprop_remove_padding;
+ }
- TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
-};
+ if (data_format == FORMAT_NHWC) {
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
+ in_backprop->tensor<T, 4>());
+ } else {
+ *in_backprop = pre_transformed_in_backprop;
+ }
+}
// Forward declarations of the functor specializations for GPU.
namespace functor {
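For reference, the autotuning flow above is a simple cache protocol: look the convolution shape up, and on a miss time every available algorithm once and memoize the winner. A minimal, self-contained sketch of that pattern follows; AutoTuneCache, Params, Config, and PickAlgorithm are illustrative stand-ins, not the AutoTuneSingleton/StreamExecutor API.

#include <functional>
#include <limits>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for ConvParameters and AlgorithmConfig.
struct Params { std::string key; };
struct Config { int algorithm = -1; int algorithm_no_scratch = -1; };

// Process-wide memo of the best algorithm per convolution shape.
class AutoTuneCache {
 public:
  static AutoTuneCache* GetInstance() {
    static AutoTuneCache instance;
    return &instance;
  }
  bool Find(const Params& p, Config* out) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = cache_.find(p.key);
    if (it == cache_.end()) return false;
    *out = it->second;
    return true;
  }
  void Insert(const Params& p, const Config& c) {
    std::lock_guard<std::mutex> lock(mu_);
    cache_[p.key] = c;
  }
 private:
  std::mutex mu_;
  std::unordered_map<std::string, Config> cache_;
};

// The first call for a shape times every candidate once (the
// ProfileResult loop above); every later call is a cache hit.
Config PickAlgorithm(const Params& p, const std::vector<int>& algorithms,
                     const std::function<double(int)>& time_algorithm) {
  Config config;
  if (!AutoTuneCache::GetInstance()->Find(p, &config)) {
    double best_ms = std::numeric_limits<double>::infinity();
    for (int algo : algorithms) {
      const double ms = time_algorithm(algo);
      if (ms < best_ms) { best_ms = ms; config.algorithm = algo; }
    }
    AutoTuneCache::GetInstance()->Insert(p, config);
  }
  return config;
}

int main() {
  auto fake_timer = [](int algo) { return algo == 2 ? 0.5 : 1.0; };
  Config c = PickAlgorithm({"NCHW/3x3/s1"}, {0, 1, 2}, fake_timer);
  return c.algorithm;  // 2, and memoized for the next identical shape
}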
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 3ea9510afb..2926bb3a86 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -168,6 +168,43 @@ limitations under the License.
namespace tensorflow {
+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+struct LaunchConv2DBackpropInputOp {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& filter,
+ int row_stride, int col_stride, const Padding& padding,
+ Tensor* in_backprop, TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchConv2DBackpropFilterOp {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& input,
+ int row_stride, int col_stride, const Padding& padding,
+ Tensor* filter_backprop, TensorFormat data_format);
+};
+
+#ifdef GOOGLE_CUDA
+template <typename T>
+struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& filter, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& out_backprop, const Tensor& input,
+ int row_stride, int col_stride, const Padding& padding,
+ Tensor* filter_backprop, TensorFormat data_format);
+};
+#endif // GOOGLE_CUDA
+
// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
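The header now follows the launchpad convention used throughout these kernels: a primary template invoked as a functor, a CPU specialization with an inline definition, and a GPU partial specialization that is only declared under GOOGLE_CUDA and defined in the GPU translation unit. A compilable toy version of the shape of that pattern (LaunchMyOp and the device tags are illustrative):

#include <iostream>

struct CPUDevice {};
struct GPUDevice {};

// Primary template: kernels call the launcher as a functor, so one
// Compute() body serves every (Device, T) registration.
template <typename Device, typename T>
struct LaunchMyOp {
  void operator()(const T* in, T* out, int n);
};

// CPU behavior is a full specialization defined where it is used.
template <typename T>
struct LaunchMyOp<CPUDevice, T> {
  void operator()(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i) out[i] = in[i] * 2;
  }
};

#ifdef GOOGLE_CUDA
// The GPU specialization is declared here only; its definition and
// explicit instantiation live in the .cu.cc translation unit.
template <typename T>
struct LaunchMyOp<GPUDevice, T> {
  void operator()(const T* in, T* out, int n);
};
#endif  // GOOGLE_CUDA

int main() {
  const float in[3] = {1, 2, 3};
  float out[3];
  LaunchMyOp<CPUDevice, float>()(in, out, 3);  // functor-call style
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";
}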
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 2c77a38952..9de8642d41 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -58,10 +58,10 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Device, typename T>
struct LaunchGeneric {
- static void launch(OpKernelContext* ctx, const Tensor& input,
- const Tensor& filter, int row_stride, int col_stride,
- const Eigen::PaddingType& padding, Tensor* output,
- TensorFormat data_format) {
+ void operator()(OpKernelContext* ctx, const Tensor& input,
+ const Tensor& filter, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
+ TensorFormat data_format) {
CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
"supports NHWC tensor format for now.";
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
@@ -86,8 +86,7 @@ struct LaunchGeneric {
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
dim_pair);
} else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) &&
- padding == Eigen::PADDING_VALID) {
+ filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
// If the input data and filter have the same height/width,
// the 2D convolution is reduced to matrix multiplication.
const int k = // Length of reduction dimension.
@@ -104,28 +103,26 @@ struct LaunchGeneric {
functor::SpatialConvolution<Device, T>()(
ctx->eigen_device<Device>(), output->tensor<T, 4>(),
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
- padding);
+ BrainPadding2EigenPadding(padding));
}
}
};
} // namespace
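Both fast paths in LaunchGeneric collapse the convolution into a single matrix multiply; only the reduction dimension differs. A small shape sketch with illustrative sizes:

#include <cstdio>

// For a 1x1 filter with stride 1, each output pixel is an independent
// in_depth -> out_depth linear map, so the NHWC convolution is one GEMM:
//   [N*H*W, in_depth] x [in_depth, out_depth] -> [N*H*W, out_depth].
// When the filter spans the whole input under VALID padding, the output
// is 1x1 per image and the reduction dimension becomes H*W*in_depth.
int main() {
  const int N = 8, H = 32, W = 32, in_depth = 64, out_depth = 128;
  std::printf("1x1 path:       [%d x %d] * [%d x %d]\n",
              N * H * W, in_depth, in_depth, out_depth);
  std::printf("full-size path: [%d x %d] * [%d x %d]\n",
              N, H * W * in_depth, H * W * in_depth, out_depth);
}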
template <typename T>
-class LaunchConv2DOp<CPUDevice, T> {
- public:
- void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Eigen::PaddingType& padding, Tensor* output,
- TensorFormat data_format) {
+struct LaunchConv2DOp<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Padding& padding, Tensor* output,
+ TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
ctx->SetStatus(
errors::Unimplemented("Generic conv implementation only supports "
"NHWC tensor format for now."));
return;
}
- LaunchGeneric<CPUDevice, T>::launch(ctx, input, filter, row_stride,
- col_stride, padding, output,
- data_format);
+ LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
+ padding, output, data_format);
}
};
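After this change the op-level Padding attribute flows through every launcher unchanged, and the translation to Eigen's padding enum (BrainPadding2EigenPadding in the diff) happens only where an Eigen functor is actually invoked. A toy version of that boundary conversion; both enums and the helper name below are stand-ins, not the TensorFlow definitions:

// Illustrative stand-ins for the two padding enums.
enum Padding { VALID, SAME };                           // op-attribute level
enum EigenPaddingType { PADDING_VALID, PADDING_SAME };  // backend level

// Translate only at the point of use, so every launcher signature can
// stay in terms of the op-level Padding type.
inline EigenPaddingType ToEigenPadding(Padding p) {
  return p == VALID ? PADDING_VALID : PADDING_SAME;
}

int main() { return ToEigenPadding(VALID) == PADDING_VALID ? 0 : 1; }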
@@ -387,9 +384,8 @@ class Conv2DOp : public BinaryOp<T> {
return;
}
- launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_rows, stride_cols,
- BrainPadding2EigenPadding(padding_), output, data_format_);
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+ stride_rows, stride_cols, padding_, output, data_format_);
}
private:
@@ -445,10 +441,10 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
AutoTuneConv;
template <typename T>
-void LaunchConv2DOp<GPUDevice, T>::launch(
+void LaunchConv2DOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input_param, const Tensor& filter, int row_stride,
- int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+ int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmType;
@@ -492,8 +488,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
}
return;
} else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) &&
- padding == Eigen::PADDING_VALID && data_format == FORMAT_NHWC) {
+ filter.dim_size(1) == input.dim_size(2) && padding == VALID &&
+ data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
const uint64 m = input.dim_size(0);
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index 60091fc27f..e29271dff2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -32,14 +32,23 @@ namespace tensorflow {
class OpKernelContext;
template <typename Device, typename T>
-class LaunchConv2DOp {
- public:
- void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Eigen::PaddingType& padding, Tensor* output,
- TensorFormat data_format);
+struct LaunchConv2DOp {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Padding& padding, Tensor* output,
+ TensorFormat data_format);
};
+#ifdef GOOGLE_CUDA
+template <typename T>
+struct LaunchConv2DOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Padding& padding, Tensor* output,
+ TensorFormat data_format);
+};
+#endif // GOOGLE_CUDA
+
// Used to keep track of persistent memory buffers used within the op.
// It uses malloc and free to avoid the time cost of initializing the memory.
template <class T, size_t size>
@@ -55,17 +64,6 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
-#ifdef GOOGLE_CUDA
-template <typename T>
-class LaunchConv2DOp<Eigen::GpuDevice, T> {
- public:
- void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Eigen::PaddingType& padding, Tensor* output,
- TensorFormat data_format);
-};
-#endif // GOOGLE_CUDA
-
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONV_OPS_H
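The Im2ColBufferResource comment above describes the trick: allocate scratch with malloc/free so a large buffer is never value-initialized, and persist it across invocations. A simplified sketch under those assumptions; Im2ColBuffer is illustrative and omits the ResourceBase plumbing:

#include <cstdlib>
#include <mutex>

// Persistent, malloc-backed scratch buffer: malloc avoids paying to
// value-initialize megabytes of memory on every op invocation, and the
// mutex lets concurrent executions share one allocation safely.
template <class T, std::size_t size>
struct Im2ColBuffer {
  Im2ColBuffer() : data(static_cast<T*>(std::malloc(size * sizeof(T)))) {}
  ~Im2ColBuffer() { std::free(data); }
  std::mutex mu;
  T* data;
};

int main() {
  Im2ColBuffer<float, 1 << 20> buffer;  // ~4 MB of conv scratch space
  std::lock_guard<std::mutex> lock(buffer.mu);
  buffer.data[0] = 1.0f;
  return 0;
}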
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 00d7f56408..9804d7d38e 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -361,19 +361,15 @@ static void ComputeBackpropInput(const DepthwiseArgs& args,
}
}
-// Kernels to compute the input backprop for depthwise convolution.
-template <typename Device, typename T>
-struct LaunchDepthwiseConvBackpropInputOp;
-
// Computes the depthwise conv2d backprop input of 'out_backprop' by
// 'depthwise_filter' and stores the result in 'in_backprop'.
template <typename T>
struct LaunchDepthwiseConvBackpropInputOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;
- static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
- const T* out_backprop, const T* depthwise_filter,
- T* in_backprop, TensorFormat data_format) {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* depthwise_filter,
+ T* in_backprop, TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@@ -514,27 +510,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA
-template <typename T>
-struct DepthwiseConv2dBackpropInputGPULaunch {
- static void Run(const GPUDevice& d, const DepthwiseArgs args,
- const T* out_backprop, const T* filter, T* in_backprop,
- TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, T> {
- static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
- const T* out_backprop, const T* filter, T* in_backprop,
- TensorFormat data_format) {
- const GPUDevice& d = ctx->eigen_device<GPUDevice>();
- DepthwiseConv2dBackpropInputGPULaunch<T>().Run(
- d, args, out_backprop, filter, in_backprop, data_format);
- auto stream = ctx->op_device_context()->stream();
- OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
- "DepthwiseConv2dBackpropInp"
- "utGPULaunch failed"));
- }
-};
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
+extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
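These extern template lines rely on explicit instantiation: the .cc files that register the ops promise not to instantiate the GPU launchers themselves, and the .cu.cc file emits the definitions once. A compilable miniature of the mechanism (Launcher is illustrative; in the real split the definition and the explicit instantiation live in the CUDA translation unit):

#include <cstdio>

template <typename T>
struct Launcher {
  void operator()(T v);
};

// In the real layout this definition lives in the .cu.cc file; the op
// .cc files see only the declaration plus the extern line below.
template <typename T>
void Launcher<T>::operator()(T v) {
  std::printf("launching with %f\n", static_cast<double>(v));
}

// What the .cc file says: "Launcher<float> exists somewhere; do not
// implicitly instantiate it in this translation unit."
extern template struct Launcher<float>;

// What the .cu.cc file says: "emit the definitions for float here."
template struct Launcher<float>;

int main() { Launcher<float>()(2.5f); }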
@@ -598,7 +575,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
if (input_shape.num_elements() == 0) {
return;
}
- LaunchDepthwiseConvBackpropInputOp<Device, T>::launch(
+ LaunchDepthwiseConvBackpropInputOp<Device, T>()(
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
data_format_);
}
@@ -744,9 +721,9 @@ template <typename T>
struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;
- static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
- const T* out_backprop, const T* input, T* filter_backprop,
- TensorFormat data_format) {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@@ -907,35 +884,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
#if GOOGLE_CUDA
-template <typename T>
-struct DepthwiseConv2dBackpropFilterGPULaunch {
- static void Run(const GPUDevice& d, const DepthwiseArgs args,
- const T* out_backprop, const T* input, T* filter_backprop,
- TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T> {
- static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
- const T* out_backprop, const T* input, T* filter_backprop,
- TensorFormat data_format) {
- const GPUDevice& d = ctx->eigen_device<GPUDevice>();
- auto stream = ctx->op_device_context()->stream();
-
- // Initialize the results to 0.
- int num_filter_backprop =
- args.filter_rows * args.filter_cols * args.out_depth;
- perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
- num_filter_backprop);
- stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
-
- DepthwiseConv2dBackpropFilterGPULaunch<T>().Run(
- d, args, out_backprop, input, filter_backprop, data_format);
- OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
- "DepthwiseConv2dBackpropFil"
- "terGPULaunch failed"));
- }
-};
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
+extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
@@ -1001,7 +951,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
if (filter_shape.num_elements() == 0) {
return;
}
- LaunchDepthwiseConvBackpropFilterOp<Device, T>::launch(
+ LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
data_format_);
}
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 3c01546d8d..bbeeaf7895 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -54,9 +54,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, typename T>
-struct LaunchDepthwiseConvOp;
-
// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// result in 'output' at location specified by 'out_r' and 'out_c'.
//
@@ -156,9 +153,9 @@ template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;
- static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
- const T* input, const T* depthwise_filter, T* output,
- TensorFormat data_format) {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* input, const T* depthwise_filter, T* output,
+ TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@@ -248,27 +245,9 @@ extern template class LaunchConv2DOp<CPUDevice, float>;
#if GOOGLE_CUDA
-template <typename T>
-struct DepthwiseConv2dGPULaunch {
- static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input,
- const T* filter, T* output, TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvOp<GPUDevice, T> {
- static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
- const T* input, const T* filter, T* output,
- TensorFormat data_format) {
- const GPUDevice& d = ctx->eigen_device<GPUDevice>();
- DepthwiseConv2dGPULaunch<T>().Run(d, args, input, filter, output,
- data_format);
- auto stream = ctx->op_device_context()->stream();
- OP_REQUIRES(
- ctx, stream->ok(),
- errors::Internal(
- "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
- }
-};
+// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
+extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
// Extern template instantiated in conv_ops.cc.
extern template class LaunchConv2DOp<GPUDevice, float>;
@@ -393,9 +372,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// If in_depth==1, this operation is just a standard convolution, so
// invoke that op.
if (std::is_same<T, float>::value && in_depth == 1) {
- launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_, stride_, BrainPadding2EigenPadding(padding_),
- output, data_format_);
+ launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+ stride_, stride_, padding_, output, data_format_);
return;
}
@@ -417,8 +395,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
auto input_ptr = input.template flat<T>().data();
auto filter_ptr = filter.template flat<T>().data();
auto output_ptr = output->template flat<T>().data();
- LaunchDepthwiseConvOp<Device, T>::launch(
- context, args, input_ptr, filter_ptr, output_ptr, data_format_);
+ LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
+ output_ptr, data_format_);
}
private:
diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h
index 1960b02bbe..aa5b5c76f6 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.h
+++ b/tensorflow/core/kernels/depthwise_conv_op.h
@@ -56,6 +56,53 @@ struct DepthwiseArgs {
out_depth(0) {}
};
+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvOp {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvBackpropInputOp {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* filter, T* in_backprop,
+ TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvBackpropFilterOp {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format);
+};
+
+#if GOOGLE_CUDA
+template <typename T>
+struct LaunchDepthwiseConvOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* input, const T* filter, T* output,
+ TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchDepthwiseConvBackpropInputOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* filter, T* in_backprop,
+ TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchDepthwiseConvBackpropFilterOp<Eigen::GpuDevice, T> {
+ void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+ const T* out_backprop, const T* input, T* filter_backprop,
+ TensorFormat data_format);
+};
+#endif // GOOGLE_CUDA
+
} // namespace tensorflow
namespace tensorflow {
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index f63a99a730..fcfcd188d2 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -689,21 +690,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dGPULaunch {
- static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
- const T* filter, T* output, TensorFormat data_format) {
- if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
+void LaunchDepthwiseConvOp<GPUDevice, T>::operator()(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* input,
+ const T* filter, T* output, TensorFormat data_format) {
+ const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
+ data_format);
+ } else {
+ LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
data_format);
- } else {
- LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
- data_format);
- }
}
-};
+ auto stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream->ok(),
+ errors::Internal(
+ "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
+}
-template struct DepthwiseConv2dGPULaunch<float>;
-template struct DepthwiseConv2dGPULaunch<double>;
+template struct LaunchDepthwiseConvOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvOp<GPUDevice, double>;
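Each GPU launcher dispatches on the filter size so that the common 3x3 case gets a kernel with compile-time loop bounds while everything else takes the generic <-1, -1> path. A CPU-side miniature of that runtime-to-compile-time dispatch:

#include <cstdio>

// -1 means "not known at compile time": the kernel falls back to runtime
// loop bounds. A known 3x3 size lets the compiler fully unroll.
template <int kKnownFilterWidth, int kKnownFilterHeight>
void Conv(int filter_rows, int filter_cols) {
  const int rows = kKnownFilterHeight < 0 ? filter_rows : kKnownFilterHeight;
  const int cols = kKnownFilterWidth < 0 ? filter_cols : kKnownFilterWidth;
  std::printf("running %dx%d kernel (specialized: %s)\n", rows, cols,
              kKnownFilterWidth > 0 ? "yes" : "no");
}

void Dispatch(int filter_rows, int filter_cols) {
  if (filter_rows == 3 && filter_cols == 3) {
    Conv<3, 3>(filter_rows, filter_cols);    // unrolled fast path
  } else {
    Conv<-1, -1>(filter_rows, filter_cols);  // generic path
  }
}

int main() { Dispatch(3, 3); Dispatch(5, 5); }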
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -893,22 +900,26 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dBackpropInputGPULaunch {
- static void Run(const GpuDevice& d, const DepthwiseArgs args,
- const T* out_backprop, const T* filter, T* in_backprop,
- TensorFormat data_format) {
- if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
- d, args, out_backprop, filter, in_backprop, data_format);
- } else {
- LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
- d, args, out_backprop, filter, in_backprop, data_format);
- }
+void LaunchDepthwiseConvBackpropInputOp<GPUDevice, T>::operator()(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+ const T* filter, T* in_backprop, TensorFormat data_format) {
+ const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
+ d, args, out_backprop, filter, in_backprop, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
+ d, args, out_backprop, filter, in_backprop, data_format);
}
-};
+ auto stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream->ok(),
+ errors::Internal("Launch of gpu kernel for "
+ "DepthwiseConv2dBackpropInp"
+ "utGPULaunch failed"));
+}
-template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
-template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
+template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -1580,21 +1591,33 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dBackpropFilterGPULaunch {
- static void Run(const GpuDevice& d, const DepthwiseArgs args,
- const T* out_backprop, const T* input, T* filter_backprop,
- TensorFormat data_format) {
- if (args.filter_rows == 3 && args.filter_cols == 3) {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
- d, args, out_backprop, input, filter_backprop, data_format);
- } else {
- LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
- d, args, out_backprop, input, filter_backprop, data_format);
- }
+void LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T>::operator()(
+ OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+ const T* input, T* filter_backprop, TensorFormat data_format) {
+ const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+ auto stream = ctx->op_device_context()->stream();
+
+ // Initialize the results to 0.
+ int num_filter_backprop =
+ args.filter_rows * args.filter_cols * args.out_depth;
+ perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
+ num_filter_backprop * sizeof(T));
+ stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
+
+ if (args.filter_rows == 3 && args.filter_cols == 3) {
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
+ d, args, out_backprop, input, filter_backprop, data_format);
+ } else {
+ LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
+ d, args, out_backprop, input, filter_backprop, data_format);
}
-};
+ OP_REQUIRES(ctx, stream->ok(),
+ errors::Internal("Launch of gpu kernel for "
+ "DepthwiseConv2dBackpropFil"
+ "terGPULaunch failed"));
+}
-template struct DepthwiseConv2dBackpropFilterGPULaunch<float>;
-template struct DepthwiseConv2dBackpropFilterGPULaunch<double>;
+template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA
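One detail worth keeping in mind from the filter-backprop launcher: the gradient is accumulated across many contributions rather than written once, so the output buffer must be zeroed before the kernel runs; Memset32 with zero works because the all-zero bit pattern is exactly 0.0 for IEEE float and double. A CPU analogue of that zero-then-accumulate contract (AccumulateFilterGrad is illustrative):

#include <cstdio>
#include <cstring>
#include <vector>

// The filter gradient is a sum over every (batch, output-pixel)
// contribution, so the kernel effectively does filter_backprop[i] += ...
// and relies on the buffer starting at zero.
void AccumulateFilterGrad(std::vector<float>* filter_backprop) {
  // Zero first -- the GPU path does this with ThenMemset32 on the
  // device buffer before launching the kernel.
  std::memset(filter_backprop->data(), 0,
              filter_backprop->size() * sizeof(float));
  for (int contribution = 0; contribution < 4; ++contribution) {
    for (float& f : *filter_backprop) f += 0.25f;  // stand-in partial sums
  }
}

int main() {
  std::vector<float> grad(9, 42.0f);  // stale values from a prior launch
  AccumulateFilterGrad(&grad);
  std::printf("%g\n", grad[0]);  // 1, not 43
}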