about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2018-09-26 16:40:44 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-26 16:46:17 -0700
commit: 3b9c747d71f30c6a59f6529f8475d7f56a86a7c5 (patch)
tree: 96770ad0dbefecc78f27414a622acd59856ad3f7 /tensorflow
parent: 3ab16ebce6a0a9ce20120c3c2dd1f1a8cf5b2ad8 (diff)
Extract Conv2D dimensions parsing and validation into helper functions.
PiperOrigin-RevId: 214691838
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/core/kernels/conv_ops.cc 321
-rw-r--r-- tensorflow/core/kernels/conv_ops.h 44
2 files changed, 231 insertions, 134 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 717a9f40a9..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+// Returns STATUS from the enclosing function unless EXP evaluates to true.
+// Wrapped in do/while(false) so that an invocation followed by a semicolon
+// behaves as a single statement.
+#define TF_REQUIRES(EXP, STATUS)   \
+  do {                             \
+    if (!TF_PREDICT_TRUE(EXP)) {   \
+      return (STATUS);             \
+    }                              \
+  } while (false)
+
+// Fills `params` from the kernel's "dilations", "strides", "padding" and
+// "data_format" attributes and validates the result.
+//
+// Returns the error of the first attribute lookup that fails, InvalidArgument
+// if the data format cannot be parsed, if either list does not have exactly
+// four entries, if a batch/depth stride or dilation differs from 1, or if a
+// row/column stride or dilation is not positive; OK otherwise.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+                            Conv2DParameters* params) {
+  // Pull the raw attributes out of the kernel definition.
+  TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+  TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+  TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+
+  // The data format arrives as a string and has to be parsed into the enum.
+  string format_name;
+  TF_RETURN_IF_ERROR(context->GetAttr("data_format", &format_name));
+  TF_REQUIRES(FormatFromString(format_name, &params->data_format),
+              errors::InvalidArgument("Invalid data format"));
+
+  // Both lists must have four entries before any per-dimension lookup.
+  TF_REQUIRES(params->dilations.size() == 4,
+              errors::InvalidArgument("Sliding window dilations field must "
+                                      "specify 4 dimensions"));
+  TF_REQUIRES(params->strides.size() == 4,
+              errors::InvalidArgument("Sliding window strides field must "
+                                      "specify 4 dimensions"));
+
+  // Strides: batch and depth must be 1, rows and columns must be positive.
+  const int64 batch_stride =
+      GetTensorDim(params->strides, params->data_format, 'N');
+  const int64 depth_stride =
+      GetTensorDim(params->strides, params->data_format, 'C');
+  const int64 row_stride =
+      GetTensorDim(params->strides, params->data_format, 'H');
+  const int64 col_stride =
+      GetTensorDim(params->strides, params->data_format, 'W');
+  TF_REQUIRES(
+      batch_stride == 1 && depth_stride == 1,
+      errors::InvalidArgument("Current implementation does not yet support "
+                              "strides in the batch and depth dimensions."));
+  TF_REQUIRES(row_stride > 0 && col_stride > 0,
+              errors::InvalidArgument(
+                  "Row and column strides should be larger than 0."));
+
+  // Dilations follow the same rules as strides.
+  const int64 batch_dilation =
+      GetTensorDim(params->dilations, params->data_format, 'N');
+  const int64 depth_dilation =
+      GetTensorDim(params->dilations, params->data_format, 'C');
+  const int64 row_dilation =
+      GetTensorDim(params->dilations, params->data_format, 'H');
+  const int64 col_dilation =
+      GetTensorDim(params->dilations, params->data_format, 'W');
+  TF_REQUIRES(
+      batch_dilation == 1 && depth_dilation == 1,
+      errors::InvalidArgument("Current implementation does not yet support "
+                              "dilations in the batch and depth dimensions."));
+  TF_REQUIRES(
+      row_dilation > 0 && col_dilation > 0,
+      errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+  return Status::OK();
+}
+
+// Derives the full set of Conv2D dimensions from the kernel parameters and the
+// shapes of `input` and `filter`, validating them along the way.
+//
+//   params:     strides, dilations, padding and data format of the op.
+//   input:      input tensor; must be 4-D, read according to
+//               params.data_format.
+//   filter:     filter tensor; must be 4-D, indexed as
+//               [filter_rows, filter_cols, patch_depth, out_depth].
+//   dimensions: output; assigned only after every check has passed.
+//
+// Returns InvalidArgument when a rank, size or depth check fails, the error
+// from GetWindowedOutputSizeV2 if an output size cannot be computed, and OK
+// otherwise.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+                              const Tensor& input, const Tensor& filter,
+                              Conv2DDimensions* dimensions) {
+  // Check that 2D convolution input and filter have exactly 4 dimensions.
+  TF_REQUIRES(input.dims() == 4,
+              errors::InvalidArgument("input must be 4-dimensional: ",
+                                      input.shape().DebugString()));
+  TF_REQUIRES(filter.dims() == 4,
+              errors::InvalidArgument("filter must be 4-dimensional: ",
+                                      filter.shape().DebugString()));
+  // Every filter dimension, including out_depth (dim 3), is narrowed to int
+  // below, so all of them have to fit.
+  for (int i = 0; i < filter.dims(); i++) {
+    TF_REQUIRES(
+        FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+        errors::InvalidArgument("filter too large"));
+  }
+
+  // The input depth must be the same as the filter's in_depth or be evenly
+  // divisible by it.
+  const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+  const int64 patch_depth_raw = filter.dim_size(2);
+  TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+              errors::InvalidArgument("Input depth too large"));
+  TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+              errors::InvalidArgument("Patch depth too large"));
+  const int in_depth = static_cast<int>(in_depth_raw);
+  const int patch_depth = static_cast<int>(patch_depth_raw);
+  // A zero filter depth would make the modulo below divide by zero.
+  TF_REQUIRES(patch_depth > 0,
+              errors::InvalidArgument(
+                  "filter depth must be strictly positive, got ", patch_depth));
+  TF_REQUIRES(in_depth % patch_depth == 0,
+              errors::InvalidArgument(
+                  "input depth must be evenly divisible by filter depth: ",
+                  in_depth, " vs ", patch_depth));
+
+  // The last dimension for filter is out_depth (bounds-checked above).
+  const int out_depth = static_cast<int>(filter.dim_size(3));
+
+  // Input rows/height come from the 'H' dimension of the data format.
+  // The first dimension for filter is rows/height.
+  const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+  TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+              errors::InvalidArgument("Input rows too large"));
+  const int input_rows = static_cast<int>(input_rows_raw);
+  const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+  // Input columns/width come from the 'W' dimension of the data format.
+  // The second dimension for filter is columns/width.
+  const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+  TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+              errors::InvalidArgument("Input cols too large"));
+  const int input_cols = static_cast<int>(input_cols_raw);
+  const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+  // Batch size comes from the 'N' dimension of the data format.
+  const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+  TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+              errors::InvalidArgument("batch is too large"));
+  const int batch = static_cast<int>(batch_raw);
+
+  // Take the stride and dilation from the row and column dimensions only (we
+  // do not support striding or dilation on the batch or depth dimension).
+  const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+  const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+  const int dilation_rows =
+      GetTensorDim(params.dilations, params.data_format, 'H');
+  const int dilation_cols =
+      GetTensorDim(params.dilations, params.data_format, 'W');
+
+  // Compute windowed output sizes for rows and columns.
+  int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+      input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+      &out_rows, &pad_rows));
+  TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+      input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+      &out_cols, &pad_cols));
+
+  // All checks passed: publish the results.
+  dimensions->batch = batch;
+  dimensions->input_rows = input_rows;
+  dimensions->input_cols = input_cols;
+  dimensions->in_depth = in_depth;
+  dimensions->filter_rows = filter_rows;
+  dimensions->filter_cols = filter_cols;
+  dimensions->patch_depth = patch_depth;
+  dimensions->out_depth = out_depth;
+  dimensions->stride_rows = stride_rows;
+  dimensions->stride_cols = stride_cols;
+  dimensions->dilation_rows = dilation_rows;
+  dimensions->dilation_cols = dilation_cols;
+  dimensions->out_rows = out_rows;
+  dimensions->out_cols = out_cols;
+  dimensions->pad_rows = pad_rows;
+  dimensions->pad_cols = pad_cols;
+
+  return Status::OK();
+}
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index adf4601b43..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;  // "dilations" attr; validated to hold 4 entries.
+ std::vector<int32> strides;  // "strides" attr; validated to hold 4 entries.
+ Padding padding;  // "padding" attr.
+ TensorFormat data_format;  // Parsed from the "data_format" attr string.
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;  // 'N' dimension of the input.
+ int input_rows;  // 'H' dimension of the input.
+ int input_cols;  // 'W' dimension of the input.
+ int in_depth;  // 'C' dimension of the input.
+
+ int filter_rows;  // filter.dim_size(0).
+ int filter_cols;  // filter.dim_size(1).
+ int patch_depth;  // filter.dim_size(2); in_depth must be a multiple of it.
+ int out_depth;  // filter.dim_size(3).
+
+ int stride_rows;  // 'H' entry of the strides parameter.
+ int stride_cols;  // 'W' entry of the strides parameter.
+
+ int dilation_rows;  // 'H' entry of the dilations parameter.
+ int dilation_cols;  // 'W' entry of the dilations parameter.
+
+ int64 out_rows;  // Row output size from GetWindowedOutputSizeV2.
+ int64 out_cols;  // Column output size from GetWindowedOutputSizeV2.
+ int64 pad_rows;  // Row padding from GetWindowedOutputSizeV2.
+ int64 pad_cols;  // Column padding from GetWindowedOutputSizeV2.
+};
+
+// Reads the dilations, strides, padding and data_format Op attributes into
+// `params` and validates them, returning an error Status on failure.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolution dimensions from Conv2D parameters. If
+// parameters are valid, dimensions will be updated with derived convolution
+// dimensions, otherwise an error will be returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_