author: A. Unique TensorFlower <gardener@tensorflow.org> (2017-09-04 03:19:53 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org> (2017-09-04 03:24:08 -0700)
commit: 07356b48e4b374efd406fd142faa77cfa4db05e9 (patch)
tree: f5049f7ef36486535e386934f3dfc48f72831f45
parent: 0302320e11c7561cafac1cc279fea87de02b0cf9 (diff)
Expose the launcher for conv2d backprop, and unify the launchers for conv2d and depthwise_conv to match the example in the documentation (see ./extend/adding_an_op.md)
PiperOrigin-RevId: 167480081
-rw-r--r-- | tensorflow/core/kernels/conv_grad_filter_ops.cc | 690
-rw-r--r-- | tensorflow/core/kernels/conv_grad_input_ops.cc | 716
-rw-r--r-- | tensorflow/core/kernels/conv_grad_ops.h | 37
-rw-r--r-- | tensorflow/core/kernels/conv_ops.cc | 42
-rw-r--r-- | tensorflow/core/kernels/conv_ops.h | 32
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_grad_op.cc | 74
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op.cc | 42
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op.h | 47
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 103
9 files changed, 922 insertions, 861 deletions
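The refactor below applies one pattern throughout: per-kernel `static void launch(...)` methods and ad-hoc helpers such as `LaunchBackwardInputConvolution` are replaced by `LaunchConv2DBackpropInputOp` / `LaunchConv2DBackpropFilterOp` functor structs, declared once as a primary template and specialized per device, so each `OpKernel` can hold a `launcher_` member and stay device-agnostic, as in the example in ./extend/adding_an_op.md. A minimal self-contained sketch of that pattern follows, with hypothetical simplified signatures (the real launchers take `OpKernelContext*`, `Tensor`s, strides, `Padding`, and `TensorFormat`):

```cpp
// Sketch of the device-specialized launcher pattern used in this commit.
#include <iostream>

struct CPUDevice {};
struct GPUDevice {};

// Primary template: declared for every (Device, T) pair...
template <typename Device, typename T>
struct LaunchConv2DBackpropFilterOp {
  void operator()(const T* out_backprop, const T* input, T* filter_backprop);
};

// ...and defined only through per-device specializations.
template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
  void operator()(const T* out_backprop, const T* input, T* filter_backprop) {
    std::cout << "CPU path (Eigen spatial convolution)\n";
  }
};

template <typename T>
struct LaunchConv2DBackpropFilterOp<GPUDevice, T> {
  void operator()(const T* out_backprop, const T* input, T* filter_backprop) {
    std::cout << "GPU path (cuDNN, with cuBLAS fast paths)\n";
  }
};

// The kernel keeps a launcher as a member, so Compute() is device-agnostic.
template <typename Device, typename T>
class Conv2DBackpropFilterKernel {
 public:
  void Compute(const T* out_backprop, const T* input, T* filter_backprop) {
    launcher_(out_backprop, input, filter_backprop);
  }

 private:
  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
};

int main() {
  float a = 0, b = 0, c = 0;
  Conv2DBackpropFilterKernel<CPUDevice, float>().Compute(&a, &b, &c);
  Conv2DBackpropFilterKernel<GPUDevice, float>().Compute(&a, &b, &c);
}
```

Because the GPU specialization's `operator()` is defined out of line (see `LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()` in the diff below), all of the cuDNN/cuBLAS logic moves out of the kernel's `Compute` method.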
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 65514937f4..8eb705b2e5 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -91,6 +91,20 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +template <typename T> +struct LaunchConv2DBackpropFilterOp<CPUDevice, T> { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_stride, int col_stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + const CPUDevice& d = ctx->eigen_device<CPUDevice>(); + functor::SpatialConvolutionBackwardKernel<CPUDevice, T>()( + d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(), + out_backprop.tensor<T, 4>(), filter_backprop->dim_size(0), + filter_backprop->dim_size(1), row_stride, col_stride); + } +}; + #ifdef TENSORFLOW_USE_LIBXSMM template <typename Device, class T> struct LaunchXsmmBackwardFilter { @@ -237,11 +251,9 @@ class Conv2DFastBackpropFilterOp : public OpKernel { } #endif - functor::SpatialConvolutionBackwardKernel<Device, T>()( - context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(), - input.tensor<T, 4>(), out_backprop.tensor<T, 4>(), - dims.spatial_dims[0].filter_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[0].stride, dims.spatial_dims[1].stride); + LaunchConv2DBackpropFilterOp<Device, T>()( + context, false, false, out_backprop, input, dims.spatial_dims[0].stride, + dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_); } private: @@ -495,15 +507,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); - cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization(); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } void Compute(OpKernelContext* context) override { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input = context->input(0); const Tensor& filter_sizes = context->input(1); const Tensor& out_backprop = context->input(2); @@ -512,352 +519,373 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { errors::InvalidArgument( "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", filter_sizes.dims())); - const TensorShape& input_shape = input.shape(); TensorShape filter_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( filter_sizes.vec<int32>(), &filter_shape)); - ConvBackpropDimensions dims; - OP_REQUIRES_OK(context, - ConvBackpropComputeDimensions( - "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2, - input.shape(), filter_shape, out_backprop.shape(), - strides_, padding_, data_format_, &dims)); - Tensor* filter_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); - const int padding_rows = - (padding_ == VALID) - ? 0 - : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * - dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); - const int padding_cols = - (padding_ == VALID) - ?
0 - : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * - dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); - - // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only - // calling it when that is true. Remove this check when (if?) cuDNN starts - // supporting different padding. - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - if (!use_cudnn_) { - context->SetStatus(errors::Unimplemented( - "Conv2DBackprop for GPU is not currently supported " - "without cudnn")); - return; - } + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - if (!cudnn_disable_conv_1x1_optimization_ && - dims.spatial_dims[0].filter_size == 1 && - dims.spatial_dims[1].filter_size == 1 && - dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && - data_format_ == FORMAT_NHWC) { - const uint64 m = dims.in_depth; - const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size; - const uint64 n = dims.out_depth; - - // The shape of output backprop is - // [batch, out_rows, out_cols, out_depth] - // From cublas's perspective, it is: n x k - auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), - out_backprop.template flat<T>().size()); - - // The shape of input is - // [batch, in_rows, in_cols, in_depth], - // From cublas's perspective, it is: m x k - auto b_ptr = AsDeviceMemory(input.template flat<T>().data(), - input.template flat<T>().size()); - - // the shape of the filter backprop from the conv_2d should be - // [1, 1, in_depth, out_depth] - // From cublas's perspective, it is: n x m - auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), - filter_backprop->template flat<T>().size()); - - bool blas_launch_status = - stream - ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, - perftools::gputools::blas::Transpose::kTranspose, - n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } else if (dims.spatial_dims[0].filter_size == - dims.spatial_dims[0].input_size && - dims.spatial_dims[1].filter_size == - dims.spatial_dims[1].input_size && - padding_ == VALID && data_format_ == FORMAT_NHWC) { - // The input data and filter have the same height/width, so call cublas - // directly. 
- const uint64 m = dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size * dims.in_depth; - const uint64 k = dims.batch_size; - const uint64 n = dims.out_depth; - - auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), - input.template flat<T>().size()); - auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), - out_backprop.template flat<T>().size()); - auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), - filter_backprop->template flat<T>().size()); - - bool blas_launch_status = - stream - ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, - perftools::gputools::blas::Transpose::kTranspose, - n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } + launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input, + stride_rows, stride_cols, padding_, filter_backprop, + data_format_); + } - Tensor compatible_input; - if (rows_odd || cols_odd) { - // If a padding dimension is odd, we have one more element on the right - // side or the bottom side. This is unsupported in cudnn. Therefore, - // we pad that extra element and make it compatible. - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum<T>::value, - ShapeFromFormat(data_format_, dims.batch_size, - dims.spatial_dims[0].input_size + rows_odd, - dims.spatial_dims[1].input_size + cols_odd, - dims.in_depth), - &compatible_input)); - - functor::PadInput<GPUDevice, T, int, 4>()( - context->template eigen_device<GPUDevice>(), - To32Bit(input.tensor<T, 4>()), {{0, 0}}, {{rows_odd, cols_odd}}, - To32Bit(compatible_input.tensor<T, 4>()), data_format_); - } else { - compatible_input = input; + private: + std::vector<int32> strides_; + Padding padding_; + bool use_cudnn_; + TensorFormat data_format_; + LaunchConv2DBackpropFilterOp<Device, T> launcher_; + bool cudnn_use_autotune_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp); +}; + +template <typename T> +void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()( + OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, int row_stride, + int col_stride, const Padding& padding, Tensor* filter_backprop, + TensorFormat data_format) { + using perftools::gputools::dnn::AlgorithmConfig; + using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::ProfileResult; + + std::vector<int32> strides(4, 1); + strides[GetTensorDimIndex(data_format, 'H')] = row_stride; + strides[GetTensorDimIndex(data_format, 'W')] = col_stride; + TensorShape filter_shape = filter_backprop->shape(); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2, + input.shape(), filter_shape, out_backprop.shape(), + strides, padding, data_format, &dims)); + + const int padding_rows = + (padding == VALID) + ? 0 + : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * + dims.spatial_dims[0].stride + + dims.spatial_dims[0].filter_size - + dims.spatial_dims[0].input_size); + const int padding_cols = + (padding == VALID) + ? 0 + : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * + dims.spatial_dims[1].stride + + dims.spatial_dims[1].filter_size - + dims.spatial_dims[1].input_size); + + // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. 
Remove this check when (if?) cuDNN starts + // supporting different padding. + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); + + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + if (!use_cudnn) { + ctx->SetStatus(errors::Unimplemented( + "Conv2DBackprop for GPU is not currently supported " + "without cudnn")); + return; + } + + bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization(); + if (!cudnn_disable_conv_1x1_optimization_ && + dims.spatial_dims[0].filter_size == 1 && + dims.spatial_dims[1].filter_size == 1 && + dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && + data_format == FORMAT_NHWC) { + const uint64 m = dims.in_depth; + const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size; + const uint64 n = dims.out_depth; + + // The shape of output backprop is + // [batch, out_rows, out_cols, out_depth] + // From cublas's perspective, it is: n x k + auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), + out_backprop.template flat<T>().size()); + + // The shape of input is + // [batch, in_rows, in_cols, in_depth], + // From cublas's perspective, it is: m x k + auto b_ptr = AsDeviceMemory(input.template flat<T>().data(), + input.template flat<T>().size()); + + // the shape of the filter backprop from the conv_2d should be + // [1, 1, in_depth, out_depth] + // From cublas's perspective, it is: n x m + auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), + filter_backprop->template flat<T>().size()); + + bool blas_launch_status = + stream + ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, n, + m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); } + return; + } else if (dims.spatial_dims[0].filter_size == + dims.spatial_dims[0].input_size && + dims.spatial_dims[1].filter_size == + dims.spatial_dims[1].input_size && + padding == VALID && data_format == FORMAT_NHWC) { + // The input data and filter have the same height/width, so call cublas + // directly. 
+ const uint64 m = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * dims.in_depth; + const uint64 k = dims.batch_size; + const uint64 n = dims.out_depth; + + auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), + input.template flat<T>().size()); + auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), + out_backprop.template flat<T>().size()); + auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), + filter_backprop->template flat<T>().size()); + + bool blas_launch_status = + stream + ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, n, + m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(dims.batch_size) - .set_height(GetTensorDim(compatible_input, data_format_, 'H')) - .set_width(GetTensorDim(compatible_input, data_format_, 'W')) - .set_feature_map_count(dims.in_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(dims.batch_size) - .set_height(dims.spatial_dims[0].output_size) - .set_width(dims.spatial_dims[1].output_size) - .set_feature_map_count(dims.out_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) - .set_input_filter_width(dims.spatial_dims[1].filter_size) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) - .set_horizontal_filter_stride(dims.spatial_dims[1].stride) - .set_zero_padding_height(padding_rows / 2) - .set_zero_padding_width(padding_cols / 2); - - // NOTE(zhengxq): - // cuDNN only supports the following layouts : - // Input : B x D x R x C - // Filter : OD x ID x R x C - // Whereas, we have - // Input : B x R x C x D - // Filter : R x C x ID x OD - // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) - // The first TransformDepth performs - // (B x R x C x D) => (B x D x R x C). - // Since the tensor returned from cuDNN is B x D x R x C also, - // the second TransformDepth performs - // (B x D x R x C) => (B x R x C x D). 
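The layout note above explains the transform calls that surround every cuDNN launch in this file: cuDNN is driven with NCHW activations and OIHW filters, while the op's tensors arrive as NHWC activations and HWIO filters, hence `NHWCToNCHW`, `TransformFilter`, and `ReverseTransformFilter`. As a plain-index illustration of the (B x R x C x D) => (B x D x R x C) shuffle (illustrative only; the real functors run on the device through Eigen):

```cpp
// Not TensorFlow code: the index remap that NHWCToNCHW performs.
#include <cstdio>
#include <vector>

// (B x R x C x D) -> (B x D x R x C), matching the comment above.
std::vector<float> NHWCToNCHW(const std::vector<float>& in, int b, int r,
                              int c, int d) {
  std::vector<float> out(in.size());
  for (int ib = 0; ib < b; ++ib)
    for (int ir = 0; ir < r; ++ir)
      for (int ic = 0; ic < c; ++ic)
        for (int id = 0; id < d; ++id)
          out[((ib * d + id) * r + ir) * c + ic] =
              in[((ib * r + ir) * c + ic) * d + id];
  return out;
}

int main() {
  // A 1x2x2x3 NHWC tensor becomes 1x3x2x2 NCHW.
  std::vector<float> nhwc = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  std::vector<float> nchw = NHWCToNCHW(nhwc, 1, 2, 2, 3);
  for (float v : nchw) std::printf("%g ", v);  // 0 3 6 9 1 4 7 10 2 5 8 11
  std::printf("\n");
}
```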
- - Tensor pre_transformed_filter_backprop; - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, - TensorShape({dims.out_depth, dims.in_depth, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[1].filter_size}), - &pre_transformed_filter_backprop)); - - Tensor transformed_out_backprop; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, - dims.spatial_dims[1].output_size, dims.out_depth); - if (dims.out_depth > 1) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, nchw_shape, - &transformed_out_backprop)); - functor::NHWCToNCHW<Device, T, 4>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), - transformed_out_backprop.tensor<T, 4>()); - } else { - // If depth <= 1, just reshape. - CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); - } + Tensor compatible_input; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum<T>::value, + ShapeFromFormat(data_format, dims.batch_size, + dims.spatial_dims[0].input_size + rows_odd, + dims.spatial_dims[1].input_size + cols_odd, + dims.in_depth), + &compatible_input)); + + functor::PadInput<GPUDevice, T, int, 4>()( + ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()), + {{0, 0}}, {{rows_odd, cols_odd}}, + To32Bit(compatible_input.tensor<T, 4>()), data_format); + } else { + compatible_input = input; + } + + CHECK(padding_rows >= 0 && padding_cols >= 0) + << "Negative row or col paddings: (" << padding_rows << ", " + << padding_cols << ")"; + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(dims.batch_size) + .set_height(GetTensorDim(compatible_input, data_format, 'H')) + .set_width(GetTensorDim(compatible_input, data_format, 'W')) + .set_feature_map_count(dims.in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(dims.batch_size) + .set_height(dims.spatial_dims[0].output_size) + .set_width(dims.spatial_dims[1].output_size) + .set_feature_map_count(dims.out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) + .set_input_filter_width(dims.spatial_dims[1].filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + .set_horizontal_filter_stride(dims.spatial_dims[1].stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(zhengxq): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). 
+ + Tensor pre_transformed_filter_backprop; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size}), + &pre_transformed_filter_backprop)); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, + dims.spatial_dims[1].output_size, dims.out_depth); + if (dims.out_depth > 1) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, + &transformed_out_backprop)); + functor::NHWCToNCHW<GPUDevice, T, 4>()( + ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(), + transformed_out_backprop.tensor<T, 4>()); } else { - transformed_out_backprop = out_backprop; + // If depth <= 1, just reshape. + CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } + } else { + transformed_out_backprop = out_backprop; + } - Tensor transformed_input; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, GetTensorDim(compatible_input, data_format_, 'N'), - GetTensorDim(compatible_input, data_format_, 'H'), - GetTensorDim(compatible_input, data_format_, 'W'), - GetTensorDim(compatible_input, data_format_, 'C')); - if (nchw_shape.dim_size(1) > 1) { - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::value, - nchw_shape, &transformed_input)); - functor::NHWCToNCHW<Device, T, 4>()( - context->eigen_device<Device>(), - const_cast<const Tensor&>(compatible_input).tensor<T, 4>(), - transformed_input.tensor<T, 4>()); - } else { - // If depth <= 1, just reshape. - CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); - } + Tensor transformed_input; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'), + GetTensorDim(compatible_input, data_format, 'H'), + GetTensorDim(compatible_input, data_format, 'W'), + GetTensorDim(compatible_input, data_format, 'C')); + if (nchw_shape.dim_size(1) > 1) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, + nchw_shape, &transformed_input)); + functor::NHWCToNCHW<GPUDevice, T, 4>()( + ctx->eigen_device<GPUDevice>(), + const_cast<const Tensor&>(compatible_input).tensor<T, 4>(), + transformed_input.tensor<T, 4>()); } else { - transformed_input = compatible_input; + // If depth <= 1, just reshape. 
+ CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); } + } else { + transformed_input = compatible_input; + } - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), - transformed_out_backprop.template flat<T>().size()); - auto filter_backprop_ptr = AsDeviceMemory( - pre_transformed_filter_backprop.template flat<T>().data(), - pre_transformed_filter_backprop.template flat<T>().size()); - auto input_ptr = - AsDeviceMemory(transformed_input.template flat<T>().data(), - transformed_input.template flat<T>().size()); - - static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default - ); - int device_id = stream->parent()->device_ordinal(); - DataType dtype = input.dtype(); - ConvParameters conv_parameters = { - dims.batch_size, // batch - dims.in_depth, // in_depths - {{input_desc.height(), // in_rows - input_desc.width()}}, // in_cols - dims.out_depth, // out_depths - {{dims.spatial_dims[0].filter_size, // filter_rows - dims.spatial_dims[1].filter_size}}, // filter_cols - {{dims.spatial_dims[0].stride, // stride_rows - dims.spatial_dims[1].stride}}, // stride_cols - {{padding_rows, // padding_rows - padding_cols}}, // padding_cols - dtype, // tensor datatype - device_id, // device_id - }; - AlgorithmConfig algorithm_config; - if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector<AlgorithmType> algorithms; - CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. 
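Both this launcher and the input-backprop one below use the same autotuning idiom, visible in the old and the new code around this point: consult a cache keyed by `ConvParameters`; on a miss, run each algorithm cuDNN reports once under a profiler, track the best result with and without scratch memory, and insert the winner. A condensed sketch of that control flow, with hypothetical simplified types standing in for `AlgorithmConfig`, `ProfileResult`, and the `AutoTuneConvBwdFilter` singleton:

```cpp
// Sketch of the autotune-and-cache pattern in the hunks above.
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <vector>

using AlgorithmId = int;

struct ProfileResult {
  bool valid = false;
  float elapsed_ms = std::numeric_limits<float>::infinity();
};

AlgorithmId AutotuneOrLookup(
    const std::string& params_key, const std::vector<AlgorithmId>& algorithms,
    const std::function<ProfileResult(AlgorithmId)>& profile,
    std::map<std::string, AlgorithmId>* cache) {
  auto it = cache->find(params_key);
  if (it != cache->end()) return it->second;  // cache hit: skip profiling
  AlgorithmId best = -1;
  float best_ms = std::numeric_limits<float>::infinity();
  for (AlgorithmId algo : algorithms) {
    ProfileResult r = profile(algo);  // one timed launch per candidate
    if (r.valid && r.elapsed_ms < best_ms) {
      best_ms = r.elapsed_ms;
      best = algo;
    }
  }
  (*cache)[params_key] = best;  // reused by later identical convolutions
  return best;
}

int main() {
  std::map<std::string, AlgorithmId> cache;
  auto profile = [](AlgorithmId a) {
    ProfileResult r;
    r.valid = true;
    r.elapsed_ms = 10.0f - a;  // pretend higher ids run faster
    return r;
  };
  AlgorithmId best = AutotuneOrLookup("b32_f3x3_s1", {0, 1, 2}, profile, &cache);
  // The second call returns immediately from the cache.
  best = AutotuneOrLookup("b32_f3x3_s1", {0, 1, 2}, profile, &cache);
  return best == 2 ? 0 : 1;
}
```

The separate no-scratch winner tracked in the real code (stored via `set_algorithm_no_scratch`) gives cuDNN a fallback for when workspace allocation fails at execution time.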
- CudnnScratchAllocator scratch_allocator( - ConvolveBackwardFilterScratchSize, context); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_desc, input_ptr, output_desc, out_backprop_ptr, - conv_desc, filter_desc, &filter_backprop_ptr, - &scratch_allocator, AlgorithmConfig(profile_algorithm), - &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), + transformed_out_backprop.template flat<T>().size()); + auto filter_backprop_ptr = + AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(), + pre_transformed_filter_backprop.template flat<T>().size()); + auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(), + transformed_input.template flat<T>().size()); + + static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + ); + int device_id = stream->parent()->device_ordinal(); + DataType dtype = input.dtype(); + ConvParameters conv_parameters = { + dims.batch_size, // batch + dims.in_depth, // in_depths + {{input_desc.height(), // in_rows + input_desc.width()}}, // in_cols + dims.out_depth, // out_depths + {{dims.spatial_dims[0].filter_size, // filter_rows + dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].stride, // stride_rows + dims.spatial_dims[1].stride}}, // stride_cols + {{padding_rows, // padding_rows + padding_cols}}, // padding_cols + dtype, // tensor datatype + device_id, // device_id + }; + AlgorithmConfig algorithm_config; + if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find( + conv_parameters, &algorithm_config)) { + std::vector<AlgorithmType> algorithms; + CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); + ProfileResult best_result; + ProfileResult best_result_no_scratch; + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. 
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, + conv_desc, filter_desc, &filter_backprop_ptr, + &scratch_allocator, AlgorithmConfig(profile_algorithm), + &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } - OP_REQUIRES(context, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } - AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters, - algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, - context); - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, - filter_desc, &filter_backprop_ptr, &scratch_allocator, - algorithm_config, nullptr) - .ok(); - - if (!cudnn_launch_status) { - context->SetStatus(errors::Internal( - "cuDNN Backward Filter function launch failure : input shape(", - input_shape.DebugString(), ") filter shape(", - filter_shape.DebugString(), ")")); - return; + OP_REQUIRES(ctx, + best_result.is_valid() || best_result_no_scratch.is_valid(), + errors::NotFound("No algorithm worked!")); + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); } - - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::ReverseTransformFilter<Device, T, 4>()( - context->eigen_device<Device>(), - toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(), - filter_backprop->tensor<T, 4>()); + if (best_result_no_scratch.is_valid()) { + algorithm_config.set_algorithm_no_scratch( + best_result_no_scratch.algorithm()); + } + AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters, + algorithm_config); + } + CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + ctx); + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, + filter_desc, &filter_backprop_ptr, &scratch_allocator, + algorithm_config, nullptr) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN Backward Filter function launch failure : input shape(", + input.shape().DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + return; } - private: - std::vector<int32> strides_; - Padding padding_; - bool use_cudnn_; - TensorFormat data_format_; - bool cudnn_use_autotune_; - bool cudnn_disable_conv_1x1_optimization_; - - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp); -}; + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::ReverseTransformFilter<GPUDevice, T, 4>()( + ctx->eigen_device<GPUDevice>(), + toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(), + filter_backprop->tensor<T, 
4>()); +} // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index a5a9549a2f..ce561aa99c 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -97,29 +97,17 @@ typedef Eigen::GpuDevice GPUDevice; // for CPU for now since nvcc times out when trying to compile them. // TODO(yangke): enable them for GPUs when we have a faster compiler. -template <typename Device, class T> -struct LaunchBackwardInputConvolution { - bool operator()(OpKernelContext* context, const Device&, - typename TTypes<T, 4>::Tensor, - typename TTypes<T, 4>::ConstTensor, - typename TTypes<T, 4>::ConstTensor, int, int, int, int, - TensorFormat) const { - return false; - } -}; - -template <> -struct LaunchBackwardInputConvolution<CPUDevice, float> { - bool operator()(OpKernelContext* context, const CPUDevice& d, - typename TTypes<float, 4>::Tensor input_backward, - typename TTypes<float, 4>::ConstTensor kernel, - typename TTypes<float, 4>::ConstTensor output_backward, - int input_rows, int input_cols, int row_stride, - int col_stride, TensorFormat data_format) const { - functor::SpatialConvolutionBackwardInput<CPUDevice, float>()( - d, input_backward, kernel, output_backward, input_rows, input_cols, - row_stride, col_stride); - return true; +template <typename T> +struct LaunchConv2DBackpropInputOp<CPUDevice, T> { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_stride, int col_stride, const Padding& padding, + Tensor* in_backprop, TensorFormat data_format) { + const CPUDevice& d = ctx->eigen_device<CPUDevice>(); + functor::SpatialConvolutionBackwardInput<CPUDevice, T>()( + d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(), + out_backprop.tensor<T, 4>(), in_backprop->dim_size(1), + in_backprop->dim_size(2), row_stride, col_stride); } }; @@ -268,11 +256,10 @@ class Conv2DFastBackpropInputOp : public OpKernel { } #endif - LaunchBackwardInputConvolution<Device, T>()( - context, context->eigen_device<Device>(), in_backprop->tensor<T, 4>(), - filter.tensor<T, 4>(), out_backprop.tensor<T, 4>(), - dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size, - dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, data_format_); + LaunchConv2DBackpropInputOp<Device, T>()( + context, false, false, out_backprop, filter, + dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_, + in_backprop, data_format_); } private: @@ -600,10 +587,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel { } void Compute(OpKernelContext* context) override { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); const Tensor& out_backprop = context->input(2); @@ -615,351 +598,372 @@ class Conv2DSlowBackpropInputOp : public OpKernel { TensorShape input_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( input_sizes.vec<int32>(), &input_shape)); - const TensorShape& filter_shape = filter.shape(); - - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - context, ConvBackpropComputeDimensions( - "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, - input_shape, filter_shape, out_backprop.shape(), strides_, 
- padding_, data_format_, &dims)); Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - const int padding_rows = - (padding_ == VALID) - ? 0 - : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * - dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); - const int padding_cols = - (padding_ == VALID) - ? 0 - : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * - dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); - - // TODO(keveman): cuDNN only supports equal padding on both sides, so only - // calling it when that is true. Remove this check when (if?) cuDNN starts - // supporting different padding. - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - if (!use_cudnn_) { - context->SetStatus(errors::Unimplemented( - "Conv2DBackpropInput for GPU is not currently supported " - "without cudnn")); - return; - } + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - if (dims.spatial_dims[0].filter_size == 1 && - dims.spatial_dims[1].filter_size == 1 && - dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && - data_format_ == FORMAT_NHWC) { - // 1x1 filter, so call cublas directly. - const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size; - const uint64 k = dims.out_depth; - const uint64 n = dims.in_depth; - - auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), - out_backprop.template flat<T>().size()); - auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), - filter.template flat<T>().size()); - auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), - in_backprop->template flat<T>().size()); - - auto transpose = perftools::gputools::blas::Transpose::kTranspose; - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; - - bool blas_launch_status = - stream - ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, - a_ptr, k, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } else if (dims.spatial_dims[0].filter_size == - dims.spatial_dims[0].input_size && - dims.spatial_dims[1].filter_size == - dims.spatial_dims[1].input_size && - padding_ == VALID && data_format_ == FORMAT_NHWC) { - // The input data and filter have the same height/width, so call cublas - // directly. 
- const uint64 m = dims.batch_size; - const uint64 k = dims.out_depth; - const uint64 n = dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size * dims.in_depth; - - auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), - out_backprop.template flat<T>().size()); - auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), - filter.template flat<T>().size()); - auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), - in_backprop->template flat<T>().size()); - - auto transpose = perftools::gputools::blas::Transpose::kTranspose; - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; - - bool blas_launch_status = - stream - ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, - a_ptr, k, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } + launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter, + stride_rows, stride_cols, padding_, in_backprop, data_format_); + } - TensorShape compatible_input_shape; - if (rows_odd || cols_odd) { - // If a padding dimension is odd, we have one more element on the right - // side or the bottom side. This is unsupported in cudnn. Therefore, - // we pad that extra element and make it compatible. - compatible_input_shape = ShapeFromFormat( - data_format_, dims.batch_size, - dims.spatial_dims[0].input_size + rows_odd, - dims.spatial_dims[1].input_size + cols_odd, dims.in_depth); - } else { - compatible_input_shape = input_shape; + private: + std::vector<int32> strides_; + Padding padding_; + bool use_cudnn_; + TensorFormat data_format_; + LaunchConv2DBackpropInputOp<Device, T> launcher_; + bool cudnn_use_autotune_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp); +}; + +template <typename T> +void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( + OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* in_backprop, + TensorFormat data_format) { + using perftools::gputools::dnn::AlgorithmConfig; + using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::ProfileResult; + + std::vector<int32> strides(4, 1); + strides[GetTensorDimIndex(data_format, 'H')] = row_stride; + strides[GetTensorDimIndex(data_format, 'W')] = col_stride; + TensorShape input_shape = in_backprop->shape(); + + const TensorShape& filter_shape = filter.shape(); + ConvBackpropDimensions dims; + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, + input_shape, filter_shape, out_backprop.shape(), + strides, padding, data_format, &dims)); + + const int padding_rows = + (padding == VALID) + ? 0 + : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) * + dims.spatial_dims[0].stride + + dims.spatial_dims[0].filter_size - + dims.spatial_dims[0].input_size); + const int padding_cols = + (padding == VALID) + ? 0 + : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) * + dims.spatial_dims[1].stride + + dims.spatial_dims[1].filter_size - + dims.spatial_dims[1].input_size); + + // TODO(keveman): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. Remove this check when (if?) cuDNN starts + // supporting different padding. 
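As the TODO above notes, cuDNN is handed only one symmetric zero-padding value per spatial dimension (`padding_rows / 2`, `padding_cols / 2`), so when the computed SAME padding total is odd the launchers first materialize an input that is one element larger (`compatible_input`) instead of requesting asymmetric padding. The arithmetic in a standalone sketch (the `pad_total` expression is the same one used for `padding_rows`/`padding_cols` in this diff):

```cpp
// SAME-padding totals and the odd/even split that triggers the pre-pad.
#include <algorithm>
#include <cstdio>

int main() {
  struct Case { int input, filter, stride; } cases[] = {{7, 3, 2}, {6, 2, 1}};
  for (const Case& c : cases) {
    int output = (c.input + c.stride - 1) / c.stride;  // SAME output size
    // Same expression as padding_rows / padding_cols above:
    int pad_total = std::max(0, (output - 1) * c.stride + c.filter - c.input);
    std::printf("in=%d filter=%d stride=%d -> out=%d pad=%d sym=%d extra=%d\n",
                c.input, c.filter, c.stride, output, pad_total, pad_total / 2,
                pad_total % 2);
    // Prints: in=7 filter=3 stride=2 -> out=4 pad=2 sym=1 extra=0
    //         in=6 filter=2 stride=1 -> out=6 pad=1 sym=0 extra=1
  }
}
```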
+ bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); + + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + if (!use_cudnn) { + ctx->SetStatus(errors::Unimplemented( + "Conv2DBackpropInput for GPU is not currently supported " + "without cudnn")); + return; + } + + if (dims.spatial_dims[0].filter_size == 1 && + dims.spatial_dims[1].filter_size == 1 && + dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && + data_format == FORMAT_NHWC) { + // 1x1 filter, so call cublas directly. + const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), + out_backprop.template flat<T>().size()); + auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), + filter.template flat<T>().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), + in_backprop->template flat<T>().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream + ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); } + return; + } else if (dims.spatial_dims[0].filter_size == + dims.spatial_dims[0].input_size && + dims.spatial_dims[1].filter_size == + dims.spatial_dims[1].input_size && + padding == VALID && data_format == FORMAT_NHWC) { + // The input data and filter have the same height/width, so call cublas + // directly. 
+ const uint64 m = dims.batch_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * dims.in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), + out_backprop.template flat<T>().size()); + auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), + filter.template flat<T>().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), + in_backprop->template flat<T>().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream + ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(dims.batch_size) - .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H')) - .set_width(GetTensorDim(compatible_input_shape, data_format_, 'W')) - .set_feature_map_count(dims.in_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(dims.batch_size) - .set_height(dims.spatial_dims[0].output_size) - .set_width(dims.spatial_dims[1].output_size) - .set_feature_map_count(dims.out_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) - .set_input_filter_width(dims.spatial_dims[1].filter_size) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) - .set_horizontal_filter_stride(dims.spatial_dims[1].stride) - .set_zero_padding_height(padding_rows / 2) - .set_zero_padding_width(padding_cols / 2); - - // NOTE(keveman): - // cuDNN only supports the following layouts : - // Input : B x D x R x C - // Filter : OD x ID x R x C - // Whereas, we have - // Input : B x R x C x D - // Filter : R x C x ID x OD - // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) - // The first TransformDepth performs - // (B x R x C x D) => (B x D x R x C). - // Since the tensor returned from cuDNN is B x D x R x C also, - // the second TransformDepth performs - // (B x D x R x C) => (B x R x C x D). 
- Tensor transformed_filter; - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, - TensorShape({dims.out_depth, dims.in_depth, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[1].filter_size}), - &transformed_filter)); - - functor::TransformFilter<Device, T, int, 4>()( - context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()), - To32Bit(transformed_filter.tensor<T, 4>())); - - Tensor transformed_out_backprop; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, - dims.spatial_dims[1].output_size, dims.out_depth); - if (dims.out_depth > 1) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, nchw_shape, - &transformed_out_backprop)); - functor::NHWCToNCHW<Device, T, 4>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), - transformed_out_backprop.tensor<T, 4>()); - } else { - // If depth <= 1, then just reshape. - CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); - } + TensorShape compatible_input_shape; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + compatible_input_shape = ShapeFromFormat( + data_format, dims.batch_size, + dims.spatial_dims[0].input_size + rows_odd, + dims.spatial_dims[1].input_size + cols_odd, dims.in_depth); + } else { + compatible_input_shape = input_shape; + } + + CHECK(padding_rows >= 0 && padding_cols >= 0) + << "Negative row or col paddings: (" << padding_rows << ", " + << padding_cols << ")"; + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(dims.batch_size) + .set_height(GetTensorDim(compatible_input_shape, data_format, 'H')) + .set_width(GetTensorDim(compatible_input_shape, data_format, 'W')) + .set_feature_map_count(dims.in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(dims.batch_size) + .set_height(dims.spatial_dims[0].output_size) + .set_width(dims.spatial_dims[1].output_size) + .set_feature_map_count(dims.out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) + .set_input_filter_width(dims.spatial_dims[1].filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + .set_horizontal_filter_stride(dims.spatial_dims[1].stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(keveman): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). 
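Looking back at the fast paths at the top of this launcher: for a 1x1, stride-1, NHWC convolution, backprop-to-input reduces to multiplying the output gradient by the transposed filter, which is why the code issues a single cuBLAS GEMM with m = batch * rows * cols, k = out_depth, n = in_depth instead of building convolution descriptors. A row-major reference implementation of that product (illustrative only; the actual call is the column-major `ThenBlasGemm` above):

```cpp
// d_input[m x n] = d_output[m x k] * filter[n x k]^T, row-major.
// m = batch*H*W, k = out_depth, n = in_depth; a [1,1,in,out] filter
// is just an n x k matrix.
#include <cstdio>
#include <vector>

void BackpropInput1x1(const std::vector<float>& d_out,
                      const std::vector<float>& filter, int m, int n, int k,
                      std::vector<float>* d_in) {
  d_in->assign(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int l = 0; l < k; ++l)
        (*d_in)[i * n + j] += d_out[i * k + l] * filter[j * k + l];
}

int main() {
  // batch*H*W = 2 pixels, in_depth = 3, out_depth = 2.
  int m = 2, n = 3, k = 2;
  std::vector<float> d_out = {1, 0, 0, 1};         // 2 x 2 gradient
  std::vector<float> filter = {1, 2, 3, 4, 5, 6};  // 3 x 2 (in x out)
  std::vector<float> d_in;
  BackpropInput1x1(d_out, filter, m, n, k, &d_in);
  for (float v : d_in) std::printf("%g ", v);  // 1 3 5 2 4 6
  std::printf("\n");
}
```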
+ Tensor transformed_filter; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size}), + &transformed_filter)); + + functor::TransformFilter<GPUDevice, T, int, 4>()( + ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()), + To32Bit(transformed_filter.tensor<T, 4>())); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, + dims.spatial_dims[1].output_size, dims.out_depth); + if (dims.out_depth > 1) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, + &transformed_out_backprop)); + functor::NHWCToNCHW<GPUDevice, T, 4>()( + ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(), + transformed_out_backprop.tensor<T, 4>()); } else { - transformed_out_backprop = out_backprop; + // If depth <= 1, then just reshape. + CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } + } else { + transformed_out_backprop = out_backprop; + } - Tensor pre_transformed_in_backprop; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum<T>::value, - ShapeFromFormat( - FORMAT_NCHW, - GetTensorDim(compatible_input_shape, data_format_, 'N'), - GetTensorDim(compatible_input_shape, data_format_, 'H'), - GetTensorDim(compatible_input_shape, data_format_, 'W'), - GetTensorDim(compatible_input_shape, data_format_, 'C')), - &pre_transformed_in_backprop)); - - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), - transformed_out_backprop.template flat<T>().size()); - auto filter_ptr = - AsDeviceMemory(transformed_filter.template flat<T>().data(), - transformed_filter.template flat<T>().size()); - auto in_backprop_ptr = - AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(), - pre_transformed_in_backprop.template flat<T>().size()); - - static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default - ); - CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, - context); - int device_id = stream->parent()->device_ordinal(); - DataType dtype = out_backprop.dtype(); - ConvParameters conv_parameters = { - dims.batch_size, // batch - dims.in_depth, // in_depths - {{input_desc.height(), // in_rows - input_desc.width()}}, // in_cols - dims.out_depth, // out_depths - {{dims.spatial_dims[0].filter_size, // filter_rows - dims.spatial_dims[1].filter_size}}, // filter_cols - {{dims.spatial_dims[0].stride, // stride_rows - dims.spatial_dims[1].stride}}, // stride_cols - {{padding_rows, // padding_rows - padding_cols}}, // padding_cols - dtype, // tensor data type - device_id, // device_id - }; - AlgorithmConfig algorithm_config; - if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector<AlgorithmType> algorithms; - CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. 
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, - context); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_desc, filter_ptr, output_desc, out_backprop_ptr, - conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, - AlgorithmConfig(profile_algorithm), &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + Tensor pre_transformed_in_backprop; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum<T>::value, + ShapeFromFormat( + FORMAT_NCHW, + GetTensorDim(compatible_input_shape, data_format, 'N'), + GetTensorDim(compatible_input_shape, data_format, 'H'), + GetTensorDim(compatible_input_shape, data_format, 'W'), + GetTensorDim(compatible_input_shape, data_format, 'C')), + &pre_transformed_in_backprop)); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), + transformed_out_backprop.template flat<T>().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat<T>().data(), + transformed_filter.template flat<T>().size()); + auto in_backprop_ptr = + AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(), + pre_transformed_in_backprop.template flat<T>().size()); + + static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + ); + CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx); + int device_id = stream->parent()->device_ordinal(); + DataType dtype = out_backprop.dtype(); + ConvParameters conv_parameters = { + dims.batch_size, // batch + dims.in_depth, // in_depths + {{input_desc.height(), // in_rows + input_desc.width()}}, // in_cols + dims.out_depth, // out_depths + {{dims.spatial_dims[0].filter_size, // filter_rows + dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].stride, // stride_rows + dims.spatial_dims[1].stride}}, // stride_cols + {{padding_rows, // padding_rows + padding_cols}}, // padding_cols + dtype, // tensor data type + device_id, // device_id + }; + AlgorithmConfig algorithm_config; + if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find( + conv_parameters, &algorithm_config)) { + std::vector<AlgorithmType> algorithms; + CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms)); + ProfileResult best_result; + ProfileResult best_result_no_scratch; + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. 
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, + ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, + AlgorithmConfig(profile_algorithm), &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } - OP_REQUIRES(context, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } - AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters, - algorithm_config); - } - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_desc, filter_ptr, output_desc, out_backprop_ptr, - conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, - algorithm_config, nullptr) - .ok(); - - if (!cudnn_launch_status) { - context->SetStatus(errors::Internal( - "cuDNN Backward Data function launch failure : input shape(", - input_shape.DebugString(), ") filter shape(", - filter_shape.DebugString(), ")")); - return; } - - if (rows_odd || cols_odd) { - Tensor in_backprop_remove_padding; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum<T>::value, - ShapeFromFormat(FORMAT_NCHW, - GetTensorDim(input_shape, data_format_, 'N'), - GetTensorDim(input_shape, data_format_, 'H'), - GetTensorDim(input_shape, data_format_, 'W'), - GetTensorDim(input_shape, data_format_, 'C')), - &in_backprop_remove_padding)); - - // Remove the padding for odd rows or cols. 
- functor::PadInput<GPUDevice, T, int, 4>()( - context->template eigen_device<GPUDevice>(), - To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop) - .tensor<T, 4>()), - {{0, 0}}, {{-rows_odd, -cols_odd}}, - To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW); - - pre_transformed_in_backprop = in_backprop_remove_padding; + OP_REQUIRES(ctx, + best_result.is_valid() || best_result_no_scratch.is_valid(), + errors::NotFound("No algorithm worked!")); + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); } - - if (data_format_ == FORMAT_NHWC) { - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::NCHWToNHWC<Device, T, 4>()( - context->eigen_device<Device>(), - toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(), - in_backprop->tensor<T, 4>()); - } else { - *in_backprop = pre_transformed_in_backprop; + if (best_result_no_scratch.is_valid()) { + algorithm_config.set_algorithm_no_scratch( + best_result_no_scratch.algorithm()); } + AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters, + algorithm_config); + } + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc, + input_desc, &in_backprop_ptr, &scratch_allocator, + algorithm_config, nullptr) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN Backward Data function launch failure : input shape(", + input_shape.DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + return; } - private: - std::vector<int32> strides_; - Padding padding_; - bool use_cudnn_; - TensorFormat data_format_; - bool cudnn_use_autotune_; + if (rows_odd || cols_odd) { + Tensor in_backprop_remove_padding; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum<T>::value, + ShapeFromFormat(FORMAT_NCHW, + GetTensorDim(input_shape, data_format, 'N'), + GetTensorDim(input_shape, data_format, 'H'), + GetTensorDim(input_shape, data_format, 'W'), + GetTensorDim(input_shape, data_format, 'C')), + &in_backprop_remove_padding)); + + // Remove the padding for odd rows or cols. + functor::PadInput<GPUDevice, T, int, 4>()( + ctx->template eigen_device<GPUDevice>(), + To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop) + .tensor<T, 4>()), + {{0, 0}}, {{-rows_odd, -cols_odd}}, + To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW); + + pre_transformed_in_backprop = in_backprop_remove_padding; + } - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp); -}; + if (data_format == FORMAT_NHWC) { + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::NCHWToNHWC<GPUDevice, T, 4>()( + ctx->eigen_device<GPUDevice>(), + toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(), + in_backprop->tensor<T, 4>()); + } else { + *in_backprop = pre_transformed_in_backprop; + } +} // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index 3ea9510afb..2926bb3a86 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -168,6 +168,43 @@ limitations under the License. namespace tensorflow { +// Forward declaration. 
+class OpKernelContext;
+
+template <typename Device, typename T>
+struct LaunchConv2DBackpropInputOp {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& out_backprop, const Tensor& filter,
+                  int row_stride, int col_stride, const Padding& padding,
+                  Tensor* in_backprop, TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchConv2DBackpropFilterOp {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& out_backprop, const Tensor& input,
+                  int row_stride, int col_stride, const Padding& padding,
+                  Tensor* filter_backprop, TensorFormat data_format);
+};
+
+#ifdef GOOGLE_CUDA
+template <typename T>
+struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& input, const Tensor& filter, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* output,
+                  TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& out_backprop, const Tensor& input,
+                  int row_stride, int col_stride, const Padding& padding,
+                  Tensor* filter_backprop, TensorFormat data_format);
+};
+#endif // GOOGLE_CUDA
+
// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 2c77a38952..9de8642d41 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -58,10 +58,10 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Device, typename T>
struct LaunchGeneric {
-  static void launch(OpKernelContext* ctx, const Tensor& input,
-                     const Tensor& filter, int row_stride, int col_stride,
-                     const Eigen::PaddingType& padding, Tensor* output,
-                     TensorFormat data_format) {
+  void operator()(OpKernelContext* ctx, const Tensor& input,
+                  const Tensor& filter, int row_stride, int col_stride,
+                  const Padding& padding, Tensor* output,
+                  TensorFormat data_format) {
    CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                         "supports NHWC tensor format for now.";
    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
@@ -86,8 +86,7 @@ struct LaunchGeneric {
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair);
    } else if (filter.dim_size(0) == input.dim_size(1) &&
-               filter.dim_size(1) == input.dim_size(2) &&
-               padding == Eigen::PADDING_VALID) {
+               filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
      // If the input data and filter have the same height/width,
      // the 2D convolution is reduced to matrix multiplication.
      const int k =  // Length of reduction dimension.
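The hunks above show the pattern this commit applies throughout: each launchpad loses its static launch() method and becomes a struct whose operator() does the work, so the device-specific variant is selected by template specialization and call sites hold the launcher as a plain member. What follows is a minimal, self-contained C++ sketch of that pattern for illustration only; LaunchMyOp, MyOpKernel, and the toy doubling kernel are invented names, not TensorFlow code.

#include <iostream>

struct CPUDevice {};  // stand-in for the Eigen device types

// Primary template: declared once, defined per device specialization.
template <typename Device, typename T>
struct LaunchMyOp;

template <typename T>
struct LaunchMyOp<CPUDevice, T> {
  // A functor instead of a static launch() method: callers invoke it like a
  // function, and the right specialization is picked by the Device type.
  void operator()(const T* input, T* output, int n) const {
    for (int i = 0; i < n; ++i) output[i] = input[i] * T(2);  // toy "kernel"
  }
};

template <typename Device, typename T>
class MyOpKernel {
 public:
  // Mirrors Conv2DOp::Compute: the kernel just forwards to launcher_.
  void Compute(const T* input, T* output, int n) { launcher_(input, output, n); }

 private:
  LaunchMyOp<Device, T> launcher_;  // like the launcher_ member in Conv2DOp
};

int main() {
  float in[3] = {1.f, 2.f, 3.f}, out[3];
  MyOpKernel<CPUDevice, float> op;
  op.Compute(in, out, 3);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // prints 2 4 6
}

One practical payoff is visible in the Conv2DOp hunk just below: since the launcher now takes the op's Padding attribute directly, the BrainPadding2EigenPadding conversion moves out of every call site and into the one generic CPU path that still needs an Eigen padding type.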
@@ -104,28 +103,26 @@ struct LaunchGeneric {
      functor::SpatialConvolution<Device, T>()(
          ctx->eigen_device<Device>(), output->tensor<T, 4>(),
          input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          padding);
+          BrainPadding2EigenPadding(padding));
    }
  }
};
} // namespace

template <typename T>
-class LaunchConv2DOp<CPUDevice, T> {
- public:
-  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
-              const Tensor& input, const Tensor& filter, int row_stride,
-              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
-              TensorFormat data_format) {
+struct LaunchConv2DOp<CPUDevice, T> {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& input, const Tensor& filter, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* output,
+                  TensorFormat data_format) {
    if (data_format != FORMAT_NHWC) {
      ctx->SetStatus(
          errors::Unimplemented("Generic conv implementation only supports "
                                "NHWC tensor format for now."));
      return;
    }
-    LaunchGeneric<CPUDevice, T>::launch(ctx, input, filter, row_stride,
-                                        col_stride, padding, output,
-                                        data_format);
+    LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
+                                  padding, output, data_format);
  }
};

@@ -387,9 +384,8 @@ class Conv2DOp : public BinaryOp<T> {
      return;
    }

-    launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
-                     stride_rows, stride_cols,
-                     BrainPadding2EigenPadding(padding_), output, data_format_);
+    launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+              stride_rows, stride_cols, padding_, output, data_format_);
  }

 private:
@@ -445,10 +441,10 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
                          AutoTuneConv;

template <typename T>
-void LaunchConv2DOp<GPUDevice, T>::launch(
+void LaunchConv2DOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& input_param, const Tensor& filter, int row_stride,
-    int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+    int col_stride, const Padding& padding, Tensor* output,
    TensorFormat data_format) {
  using perftools::gputools::dnn::AlgorithmConfig;
  using perftools::gputools::dnn::AlgorithmType;
@@ -492,8 +488,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
    }
    return;
  } else if (filter.dim_size(0) == input.dim_size(1) &&
-             filter.dim_size(1) == input.dim_size(2) &&
-             padding == Eigen::PADDING_VALID && data_format == FORMAT_NHWC) {
+             filter.dim_size(1) == input.dim_size(2) && padding == VALID &&
+             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = input.dim_size(0);
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index 60091fc27f..e29271dff2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -32,14 +32,23 @@ namespace tensorflow {
class OpKernelContext;

template <typename Device, typename T>
-class LaunchConv2DOp {
- public:
-  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
-              const Tensor& input, const Tensor& filter, int row_stride,
-              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
-              TensorFormat data_format);
+struct LaunchConv2DOp {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& input, const Tensor& filter, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* output,
+                  TensorFormat data_format);
};

+#ifdef GOOGLE_CUDA
+template <typename T>
+struct LaunchConv2DOp<Eigen::GpuDevice, T> {
+  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+                  const Tensor& input, const Tensor& filter, int row_stride,
+                  int col_stride, const Padding& padding, Tensor* output,
+                  TensorFormat data_format);
+};
+#endif // GOOGLE_CUDA
+
// Used to keep track of persistent memory buffers used within the op.
// It uses malloc and free to avoid the time cost of initializing the memory.
template <class T, size_t size>
@@ -55,17 +64,6 @@ struct Im2ColBufferResource : public ResourceBase {
  string DebugString() { return "Im2ColBufferResource"; }
};

-#ifdef GOOGLE_CUDA
-template <typename T>
-class LaunchConv2DOp<Eigen::GpuDevice, T> {
- public:
-  void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
-              const Tensor& input, const Tensor& filter, int row_stride,
-              int col_stride, const Eigen::PaddingType& padding, Tensor* output,
-              TensorFormat data_format);
-};
-#endif // GOOGLE_CUDA
-
} // namespace tensorflow

#endif // TENSORFLOW_KERNELS_CONV_OPS_H
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 00d7f56408..9804d7d38e 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -361,19 +361,15 @@ static void ComputeBackpropInput(const DepthwiseArgs& args,
  }
}

-// Kernels to compute the input backprop for depthwise convolution.
-template <typename Device, typename T>
-struct LaunchDepthwiseConvBackpropInputOp;
-
// Computes the depthwise conv2d backprop input of 'out_backprop' by
// 'depthwise_filter' and stores the result in 'in_backprop'.
template <typename T>
struct LaunchDepthwiseConvBackpropInputOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

-  static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
-                     const T* out_backprop, const T* depthwise_filter,
-                     T* in_backprop, TensorFormat data_format) {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* depthwise_filter,
+                  T* in_backprop, TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
@@ -514,27 +510,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,

#if GOOGLE_CUDA

-template <typename T>
-struct DepthwiseConv2dBackpropInputGPULaunch {
-  static void Run(const GPUDevice& d, const DepthwiseArgs args,
-                  const T* out_backprop, const T* filter, T* in_backprop,
-                  TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, T> {
-  static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
-                     const T* out_backprop, const T* filter, T* in_backprop,
-                     TensorFormat data_format) {
-    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-    DepthwiseConv2dBackpropInputGPULaunch<T>().Run(
-        d, args, out_backprop, filter, in_backprop, data_format);
-    auto stream = ctx->op_device_context()->stream();
-    OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
-                                                    "DepthwiseConv2dBackpropInp"
-                                                    "utGPULaunch failed"));
-  }
-};
+extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;

#endif // GOOGLE_CUDA

@@ -598,7 +575,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
    if (input_shape.num_elements() == 0) {
      return;
    }
-    LaunchDepthwiseConvBackpropInputOp<Device, T>::launch(
+    LaunchDepthwiseConvBackpropInputOp<Device, T>()(
        context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
        data_format_);
  }
@@ -744,9 +721,9 @@ template <typename T>
struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

-  static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
-                     const T* out_backprop, const T* input, T* filter_backprop,
-                     TensorFormat data_format) {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* input, T* filter_backprop,
+                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
@@ -907,35 +884,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,

#if GOOGLE_CUDA

-template <typename T>
-struct DepthwiseConv2dBackpropFilterGPULaunch {
-  static void Run(const GPUDevice& d, const DepthwiseArgs args,
-                  const T* out_backprop, const T* input, T* filter_backprop,
-                  TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T> {
-  static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
-                     const T* out_backprop, const T* input, T* filter_backprop,
-                     TensorFormat data_format) {
-    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-    auto stream = ctx->op_device_context()->stream();
-
-    // Initialize the results to 0.
-    int num_filter_backprop =
-        args.filter_rows * args.filter_cols * args.out_depth;
-    perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
-                                                        num_filter_backprop);
-    stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
-
-    DepthwiseConv2dBackpropFilterGPULaunch<T>().Run(
-        d, args, out_backprop, input, filter_backprop, data_format);
-    OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
-                                                    "DepthwiseConv2dBackpropFil"
-                                                    "terGPULaunch failed"));
-  }
-};
+extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;

#endif // GOOGLE_CUDA

@@ -1001,7 +951,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
    if (filter_shape.num_elements() == 0) {
      return;
    }
-    LaunchDepthwiseConvBackpropFilterOp<Device, T>::launch(
+    LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
        context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
        data_format_);
  }
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 3c01546d8d..bbeeaf7895 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -54,9 +54,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

-template <typename Device, typename T>
-struct LaunchDepthwiseConvOp;
-
// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// result in 'output' at location specified by 'out_r' and 'out_c'.
//
@@ -156,9 +153,9 @@ template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

-  static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
-                     const T* input, const T* depthwise_filter, T* output,
-                     TensorFormat data_format) {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* input, const T* depthwise_filter, T* output,
+                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
@@ -248,27 +245,9 @@ extern template class LaunchConv2DOp<CPUDevice, float>;

#if GOOGLE_CUDA

-template <typename T>
-struct DepthwiseConv2dGPULaunch {
-  static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input,
-                  const T* filter, T* output, TensorFormat data_format);
-};
-
-template <typename T>
-struct LaunchDepthwiseConvOp<GPUDevice, T> {
-  static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
-                     const T* input, const T* filter, T* output,
-                     TensorFormat data_format) {
-    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-    DepthwiseConv2dGPULaunch<T>().Run(d, args, input, filter, output,
-                                      data_format);
-    auto stream = ctx->op_device_context()->stream();
-    OP_REQUIRES(
-        ctx, stream->ok(),
-        errors::Internal(
-            "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
-  }
-};
+// Extern template instantiated in depthwise_conv_op_gpu.cc.
+extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
+extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

// Extern template instantiated in conv_ops.cc.
extern template class LaunchConv2DOp<GPUDevice, float>;
@@ -393,9 +372,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
      // If in_depth==1, this operation is just a standard convolution, so
      // invoke that op.
      if (std::is_same<T, float>::value && in_depth == 1) {
-        launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
-                         stride_, stride_, BrainPadding2EigenPadding(padding_),
-                         output, data_format_);
+        launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+                  stride_, stride_, padding_, output, data_format_);
        return;
      }

@@ -417,8 +395,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
-    LaunchDepthwiseConvOp<Device, T>::launch(
-        context, args, input_ptr, filter_ptr, output_ptr, data_format_);
+    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
+                                       output_ptr, data_format_);
  }

 private:
diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h
index 1960b02bbe..aa5b5c76f6 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.h
+++ b/tensorflow/core/kernels/depthwise_conv_op.h
@@ -56,6 +56,53 @@ struct DepthwiseArgs {
        out_depth(0) {}
};

+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvOp {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* input, const T* filter, T* output,
+                  TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvBackpropInputOp {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* filter, T* in_backprop,
+                  TensorFormat data_format);
+};
+
+template <typename Device, typename T>
+struct LaunchDepthwiseConvBackpropFilterOp {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* input, T* filter_backprop,
+                  TensorFormat data_format);
+};
+
+#if GOOGLE_CUDA
+template <typename T>
+struct LaunchDepthwiseConvOp<Eigen::GpuDevice, T> {
+  void operator()(OpKernelContext* ctx, const DepthwiseArgs args,
+                  const T* input, const T* filter, T* output,
+                  TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchDepthwiseConvBackpropInputOp<Eigen::GpuDevice, T> {
+  void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* filter, T* in_backprop,
+                  TensorFormat data_format);
+};
+
+template <typename T>
+struct LaunchDepthwiseConvBackpropFilterOp<Eigen::GpuDevice, T> {
+  void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args,
+                  const T* out_backprop, const T* input, T* filter_backprop,
+                  TensorFormat data_format);
+};
+#endif
+
} // namespace tensorflow

namespace tensorflow {
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index f63a99a730..fcfcd188d2 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -689,21 +690,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dGPULaunch {
-  static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
-                  const T* filter, T* output, TensorFormat data_format) {
-    if (args.filter_rows == 3 && args.filter_cols == 3) {
-      LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
+void LaunchDepthwiseConvOp<GPUDevice, T>::operator()(OpKernelContext* ctx,
+                                                     const DepthwiseArgs args,
+                                                     const T* input,
+                                                     const T* filter, T* output,
+                                                     TensorFormat data_format) {
+  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+  if (args.filter_rows == 3 && args.filter_cols == 3) {
+    LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
+                                      data_format);
+  } else {
+    LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
                                        data_format);
-    } else {
-      LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
-                                          data_format);
-    }
  }
-};
+  auto stream = ctx->op_device_context()->stream();
+  OP_REQUIRES(ctx, stream->ok(),
+              errors::Internal(
+                  "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
+}

-template struct DepthwiseConv2dGPULaunch<float>;
-template struct DepthwiseConv2dGPULaunch<double>;
+template struct LaunchDepthwiseConvOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvOp<GPUDevice, double>;

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -893,22 +900,26 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dBackpropInputGPULaunch {
-  static void Run(const GpuDevice& d, const DepthwiseArgs args,
-                  const T* out_backprop, const T* filter, T* in_backprop,
-                  TensorFormat data_format) {
-    if (args.filter_rows == 3 && args.filter_cols == 3) {
-      LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
-          d, args, out_backprop, filter, in_backprop, data_format);
-    } else {
-      LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
-          d, args, out_backprop, filter, in_backprop, data_format);
-    }
+void LaunchDepthwiseConvBackpropInputOp<GPUDevice, T>::operator()(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+    const T* filter, T* in_backprop, TensorFormat data_format) {
+  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+  if (args.filter_rows == 3 && args.filter_cols == 3) {
+    LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
+        d, args, out_backprop, filter, in_backprop, data_format);
+  } else {
+    LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
+        d, args, out_backprop, filter, in_backprop, data_format);
  }
-};
+  auto stream = ctx->op_device_context()->stream();
+  OP_REQUIRES(ctx, stream->ok(),
+              errors::Internal("Launch of gpu kernel for "
+                               "DepthwiseConv2dBackpropInp"
+                               "utGPULaunch failed"));
+}

-template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
-template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
+template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@@ -1580,21 +1591,33 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
-struct DepthwiseConv2dBackpropFilterGPULaunch {
-  static void Run(const GpuDevice& d, const DepthwiseArgs args,
-                  const T* out_backprop, const T* input, T* filter_backprop,
-                  TensorFormat data_format) {
-    if (args.filter_rows == 3 && args.filter_cols == 3) {
-      LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
-          d, args, out_backprop, input, filter_backprop, data_format);
-    } else {
-      LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
-          d, args, out_backprop, input, filter_backprop, data_format);
-    }
+void LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T>::operator()(
+    OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
+    const T* input, T* filter_backprop, TensorFormat data_format) {
+  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+  auto stream = ctx->op_device_context()->stream();
+
+  // Initialize the results to 0.
+  int num_filter_backprop =
+      args.filter_rows * args.filter_cols * args.out_depth;
+  perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
+                                                      num_filter_backprop);
+  stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
+
+  if (args.filter_rows == 3 && args.filter_cols == 3) {
+    LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
+        d, args, out_backprop, input, filter_backprop, data_format);
+  } else {
+    LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
+        d, args, out_backprop, input, filter_backprop, data_format);
  }
-};
+  OP_REQUIRES(ctx, stream->ok(),
+              errors::Internal("Launch of gpu kernel for "
+                               "DepthwiseConv2dBackpropFil"
+                               "terGPULaunch failed"));
+}

-template struct DepthwiseConv2dBackpropFilterGPULaunch<float>;
-template struct DepthwiseConv2dBackpropFilterGPULaunch<double>;
+template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
+template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;

} // namespace tensorflow

#endif // GOOGLE_CUDA
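A note on the pattern that closes out this commit: the GPU operator() bodies above live only in depthwise_conv_op_gpu.cu.cc, which ends with explicit instantiation definitions (template struct LaunchDepthwiseConvOp<GPUDevice, float>; and friends), while depthwise_conv_op.cc and depthwise_conv_grad_op.cc carry the matching extern template declarations so they link against those instantiations instead of compiling the CUDA definitions themselves. Below is a compressed single-file sketch of that C++11 mechanism; the Launcher type is invented purely for illustration and is not TensorFlow code.

#include <cstdio>

template <typename T>
struct Launcher {
  void operator()(T v) { std::printf("launched with %g\n", static_cast<double>(v)); }
};

// In the real split this declaration sits in the .cc files: it promises that
// Launcher<float> is explicitly instantiated in some other translation unit,
// so no instantiation is emitted here.
extern template struct Launcher<float>;

// And this definition sits in the .cu.cc file: it forces the instantiation
// that the declaration above refers to.
template struct Launcher<float>;

int main() {
  Launcher<float>()(2.5f);  // resolves to the explicitly instantiated operator()
  return 0;
}

Here both lines share one file so the sketch compiles standalone; in the kernel sources they are split across translation units, which keeps the CUDA-only definitions out of the CPU build while the launcher declarations in depthwise_conv_op.h stay device-generic.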