diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-09-26 16:40:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 16:46:17 -0700 |
commit | 3b9c747d71f30c6a59f6529f8475d7f56a86a7c5 (patch) | |
tree | 96770ad0dbefecc78f27414a622acd59856ad3f7 /tensorflow | |
parent | 3ab16ebce6a0a9ce20120c3c2dd1f1a8cf5b2ad8 (diff) |
Extract Conv2D dimensions parsing and validation into helper functions.
PiperOrigin-RevId: 214691838
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/kernels/conv_ops.cc | 321 | ||||
-rw-r--r-- | tensorflow/core/kernels/conv_ops.h | 44 |
2 files changed, 231 insertions, 134 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 717a9f40a9..78856c4a99 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> { }; #endif +#define TF_REQUIRES(EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \ + } while (false) + +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params) { + TF_RETURN_IF_ERROR(context->GetAttr("dilations", ¶ms->dilations)); + TF_RETURN_IF_ERROR(context->GetAttr("strides", ¶ms->strides)); + TF_RETURN_IF_ERROR(context->GetAttr("padding", ¶ms->padding)); + string data_format_string; + TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string)); + TF_REQUIRES(FormatFromString(data_format_string, ¶ms->data_format), + errors::InvalidArgument("Invalid data format")); + + const auto& strides = params->strides; + const auto& dilations = params->dilations; + const auto& data_format = params->data_format; + + TF_REQUIRES(dilations.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + TF_REQUIRES(strides.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides, data_format, 'N'); + const int64 stride_c = GetTensorDim(strides, data_format, 'C'); + const int64 stride_h = GetTensorDim(strides, data_format, 'H'); + const int64 stride_w = GetTensorDim(strides, data_format, 'W'); + TF_REQUIRES( + stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + TF_REQUIRES(stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + + const int64 dilation_n = GetTensorDim(dilations, data_format, 'N'); + const int64 dilation_c = GetTensorDim(dilations, data_format, 'C'); + const int64 dilation_h = GetTensorDim(dilations, data_format, 'H'); + const int64 dilation_w = GetTensorDim(dilations, data_format, 'W'); + TF_REQUIRES( + dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + TF_REQUIRES( + dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + + return Status::OK(); +} + +Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions) { + // Check that 2D convolution input and filter have exactly 4 dimensions. + TF_REQUIRES(input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + TF_REQUIRES(filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().DebugString())); + for (int i = 0; i < 3; i++) { + TF_REQUIRES( + FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), + errors::InvalidArgument("filter too large")); + } + + // The last dimension for input is in_depth. Check that it is the same as the + // filter's in_depth or it is evenly divisible by filter's in_depth. + const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C'); + const int64 patch_depth_raw = filter.dim_size(2); + TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input depth too large")); + TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Patch depth too large")); + const int in_depth = static_cast<int>(in_depth_raw); + const int patch_depth = static_cast<int>(patch_depth_raw); + TF_REQUIRES(in_depth % patch_depth == 0, + errors::InvalidArgument( + "input depth must be evenly divisible by filter depth: ", + in_depth, " vs ", patch_depth)); + + // The last dimension for filter is out_depth. + const int out_depth = static_cast<int>(filter.dim_size(3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H'); + TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input rows too large")); + const int input_rows = static_cast<int>(input_rows_raw); + const int filter_rows = static_cast<int>(filter.dim_size(0)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W'); + TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input cols too large")); + const int input_cols = static_cast<int>(input_cols_raw); + const int filter_cols = static_cast<int>(filter.dim_size(1)); + + // The first dimension for input is batch. + const int64 batch_raw = GetTensorDim(input, params.data_format, 'N'); + TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("batch is too large")); + const int batch = static_cast<int>(batch_raw); + + // Take the stride and dilation from the second and third dimensions only (we + // do not support striding or dilation on the batch or depth dimension). + const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H'); + const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W'); + const int dilation_rows = + GetTensorDim(params.dilations, params.data_format, 'H'); + const int dilation_cols = + GetTensorDim(params.dilations, params.data_format, 'W'); + + // Compute windowed output sizes for rows and columns. + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_rows, filter_rows, dilation_rows, stride_rows, params.padding, + &out_rows, &pad_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_cols, filter_cols, dilation_cols, stride_cols, params.padding, + &out_cols, &pad_cols)); + + dimensions->batch = batch; + dimensions->input_rows = input_rows; + dimensions->input_cols = input_cols; + dimensions->in_depth = in_depth; + dimensions->filter_rows = filter_rows; + dimensions->filter_cols = filter_cols; + dimensions->patch_depth = patch_depth; + dimensions->out_depth = out_depth; + dimensions->stride_rows = stride_rows; + dimensions->stride_cols = stride_cols; + dimensions->dilation_rows = dilation_rows; + dimensions->dilation_cols = dilation_cols; + dimensions->out_rows = out_rows; + dimensions->out_cols = out_cols; + dimensions->pad_rows = pad_rows; + dimensions->pad_cols = pad_cols; + + return Status::OK(); +} + +#undef TF_REQUIRES + template <typename Device, typename T> class Conv2DOp : public BinaryOp<T> { public: explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) { - OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, InitConv2DParameters(context, ¶ms_)); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); - const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); - const int64 stride_h = GetTensorDim(strides_, data_format_, 'H'); - const int64 stride_w = GetTensorDim(strides_, data_format_, 'W'); - OP_REQUIRES( - context, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES(context, stride_h > 0 && stride_w > 0, - errors::InvalidArgument( - "Row and column strides should be larger than 0.")); - - const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } void Compute(OpKernelContext* context) override { // Input tensor is of the following dimensions: // [ batch, in_rows, in_cols, in_depth ] - const Tensor& input = context->input(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, in_depth, out_depth] const Tensor& filter = context->input(1); - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), - errors::InvalidArgument("filter too large")); - } + Conv2DDimensions dimensions; + OP_REQUIRES_OK(context, + ComputeConv2DDimension(params_, input, filter, &dimensions)); - // The last dimension for input is in_depth. It must be the same as the - // filter's in_depth or be evenly divisible by filter's in_depth. - const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - const int64 patch_depth = filter.dim_size(2); - OP_REQUIRES(context, in_depth % patch_depth == 0, - errors::InvalidArgument( - "input depth must be evenly divisible by filter depth: ", - in_depth, " vs ", patch_depth)); - - // The last dimension for filter is out_depth. - const int out_depth = static_cast<int>(filter.dim_size(3)); - - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. - const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input rows too large")); - const int input_rows = static_cast<int>(input_rows_raw); - const int filter_rows = static_cast<int>(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input cols too large")); - const int input_cols = static_cast<int>(input_cols_raw); - const int filter_cols = static_cast<int>(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES(context, - FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("batch is too large")); - const int batch = static_cast<int>(batch_raw); - - // For now we take the stride and dilation from the second and third - // dimensions only (we do not support striding or dilation on the batch or - // depth dimension). - const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - - const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); - const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_rows, filter_rows, dilation_rows, - stride_rows, padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_cols, filter_cols, dilation_cols, - stride_cols, padding_, &out_cols, &pad_cols)); - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + TensorShape out_shape = ShapeFromFormat( + params_.data_format, dimensions.batch, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth); // Output tensor is of the following dimensions: // [ in_batch, out_rows, out_cols, out_depth ] Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - VLOG(2) << "Conv2D: in_depth = " << in_depth - << ", patch_depth = " << patch_depth - << ", input_cols = " << input_cols - << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows - << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", dilation_rows = " << dilation_rows - << ", dilation_cols = " << dilation_cols - << ", out_depth = " << out_depth; + VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth + << ", patch_depth = " << dimensions.patch_depth + << ", input_cols = " << dimensions.input_cols + << ", filter_cols = " << dimensions.filter_cols + << ", input_rows = " << dimensions.input_rows + << ", filter_rows = " << dimensions.filter_rows + << ", stride_rows = " << dimensions.stride_rows + << ", stride_cols = " << dimensions.stride_cols + << ", dilation_rows = " << dimensions.dilation_rows + << ", dilation_cols = " << dimensions.dilation_cols + << ", out_depth = " << dimensions.out_depth; // If there is nothing to compute, return. if (out_shape.num_elements() == 0) { @@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> { #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS if (LaunchXsmmConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } #endif if (LaunchDeepConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, - dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, - output, data_format_); + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, params_.padding, + output, params_.data_format); } private: - std::vector<int32> dilations_; - std::vector<int32> strides_; + Conv2DParameters params_; bool use_cudnn_; - Padding padding_; - TensorFormat data_format_; - LaunchConv2DOp<Device, T> launcher_; bool cudnn_use_autotune_; + LaunchConv2DOp<Device, T> launcher_; + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); }; diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index adf4601b43..7ec878e0b2 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase { string DebugString() { return "Im2ColBufferResource"; } }; +// Convolution parameters specified by Op attributes. +struct Conv2DParameters { + std::vector<int32> dilations; + std::vector<int32> strides; + Padding padding; + TensorFormat data_format; +}; + +// Convolution dimensions inferred from parameters, input and filter tensors. +struct Conv2DDimensions { + int batch; + int input_rows; + int input_cols; + int in_depth; + + int filter_rows; + int filter_cols; + int patch_depth; + int out_depth; + + int stride_rows; + int stride_cols; + + int dilation_rows; + int dilation_cols; + + int64 out_rows; + int64 out_cols; + int64 pad_rows; + int64 pad_cols; +}; + +// Initializes and validates Conv2D parameters configured by OpKernel +// attributes. +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params); + +// Computes and validates convolutions dimensions from Conv2D parameters. If +// parameters are valid, dimensions will be updated with derived convolution +// dimensions, otherwise error will be returned. +Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions); + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ |