diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 13:04:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 13:05:05 -0700 |
commit | 9c50882415cb87a7eb81048d42401c64bf0617ef (patch) | |
tree | c550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core/kernels/mkl_conv_ops.h | |
parent | 19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff) | |
parent | 62191da0819b25906c1b2ed96159cfe36ba00383 (diff) |
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 414 |
1 files changed, 280 insertions, 134 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 838c06f49d..01cc606f41 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -79,9 +79,16 @@ class MklDnnConvUtil { // For now we take the stride from the second and third dimensions only // (we do not support striding on the batch or depth dimension). CHECK_NOTNULL(strides); - int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - *strides = {stride_rows, stride_cols}; + if (strides_.size() == 4) { + int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + *strides = {stride_rows, stride_cols}; + } else if (strides_.size() == 5) { + int stride_planes = GetTensorDim(strides_, data_format_, '0'); + int stride_rows = GetTensorDim(strides_, data_format_, '1'); + int stride_cols = GetTensorDim(strides_, data_format_, '2'); + *strides = {stride_planes, stride_rows, stride_cols}; + } } // Calculate Convolution dilations @@ -89,13 +96,20 @@ class MklDnnConvUtil { // For now we take the dilation from the second and third dimensions only // (we do not support dilation on the batch or depth dimension). CHECK_NOTNULL(dilations); - int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); - int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); - *dilations = {dilations_rows, dilations_cols}; + if (dilations_.size() == 4) { + int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); + int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); + *dilations = {dilations_rows, dilations_cols}; + } else if (dilations_.size() == 5) { + int dilations_planes = GetTensorDim(dilations_, data_format_, '0'); + int dilations_rows = GetTensorDim(dilations_, data_format_, '1'); + int dilations_cols = GetTensorDim(dilations_, data_format_, '2'); + *dilations = {dilations_planes, dilations_rows, dilations_cols}; + } } // Calculate Convolution input size in MKL-DNN order. MKL-DNN - // requires input in NCHW format. Function does not return anything. + // requires input in NCHW/NCDHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape, @@ -113,40 +127,62 @@ class MklDnnConvUtil { int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C'); int input_depth = static_cast<int>(input_depth_raw); - // Input rows/height - int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); - CHECK_BOUNDS(input_rows_raw, "Input rows too large"); - int input_rows = static_cast<int>(input_rows_raw); - - // Input columns/width - int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); - CHECK_BOUNDS(input_cols_raw, "Input cols too large"); - int input_cols = static_cast<int>(input_cols_raw); - // Input batch int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N'); CHECK_BOUNDS(input_batch_raw, "Input batch too large"); int input_batch = static_cast<int>(input_batch_raw); + if (strides_.size() == 4) { // NCHW format for Conv2D + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast<int>(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast<int>(input_cols_raw); + + // MKL-DNN always requires input in NCHW format Conv2D. + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; + + *input_dims = mkldnn_sizes; + } else if (strides_.size() == 5) { // NCDHW format for Conv3D + // Input planes/third-dimension + int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0'); + CHECK_BOUNDS(input_planes_raw, "Input depth too large"); + int input_planes = static_cast<int>(input_planes_raw); + + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast<int>(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast<int>(input_cols_raw); + + // MKL-DNN always requires input in NCDHW format for Conv3D. + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch; + mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes; + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows; + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols; + + *input_dims = mkldnn_sizes; + } #undef CHECK_BOUNDS - - // MKL-DNN always requires input in NCHW format. - std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; - mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; - mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; - - *input_dims = mkldnn_sizes; } - // Calculate Convolution filter size in MKL-DNN order. MKL-DNN - // requires filter in OIHW format. Function does not return anything. - // But errors arising from sanity checks are returned in context's - // status. - // - // Calculate Convolution filter size in MKL-DNN order. MKL-DNN - // requires filter in OIHW format. Function does not return anything. + // Calculate Convolution filter size in MKL-DNN order. + // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format. + // Function does not return anything. // But errors arising from sanity checks are returned in context's // status. This function differs from GetConvFilterSizeInMklOrder in // parameter for input - it accepts src_shape since Convolution Backward @@ -159,11 +195,13 @@ class MklDnnConvUtil { memory::dims* filter_dims) { CHECK_NOTNULL(filter_dims); - OP_REQUIRES(context_, filter_shape.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", + OP_REQUIRES(context_, filter_shape.dims() == strides_.size(), + errors::InvalidArgument((strides_.size() == 4) + ? "filter must be 4-dimensional: " + : "filter must be 5-dimensional: ", filter_shape.DebugString())); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) { OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), std::numeric_limits<int>::max()), @@ -172,32 +210,57 @@ class MklDnnConvUtil { int input_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter_shape.dim_size(2))); - - // TF filter is always in (rows, cols, in_depth, out_depth) order. - int filter_rows = static_cast<int>(filter_shape.dim_size(0)); - int filter_cols = static_cast<int>(filter_shape.dim_size(1)); - int in_depth = static_cast<int>(filter_shape.dim_size(2)); - int out_depth = static_cast<int>(filter_shape.dim_size(3)); - - // MKL-DNN always needs filter in OIHW format. - // OIHW = (out_depth, in_depth, rows, cols) - std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_O] = out_depth; - mkldnn_sizes[MklDnnDims::Dim_I] = in_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; - mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; - - *filter_dims = mkldnn_sizes; + if (strides_.size() == 4) { // Conv2D + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(2))); + + // TF filter is always in (rows, cols, in_depth, out_depth) order. + int filter_rows = static_cast<int>(filter_shape.dim_size(0)); + int filter_cols = static_cast<int>(filter_shape.dim_size(1)); + int in_depth = static_cast<int>(filter_shape.dim_size(2)); + int out_depth = static_cast<int>(filter_shape.dim_size(3)); + + // MKL-DNN always needs filter in OIHW format. + // OIHW = (out_depth, in_depth, rows, cols) + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_O] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_I] = in_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; + + *filter_dims = mkldnn_sizes; + } else { // Conv3D + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3), + errors::InvalidArgument( + "input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(3))); + + // TF filter is always in (planes, rows, cols, in_depth, out_depth) order. + int filter_planes = static_cast<int>(filter_shape.dim_size(0)); + int filter_rows = static_cast<int>(filter_shape.dim_size(1)); + int filter_cols = static_cast<int>(filter_shape.dim_size(2)); + int in_depth = static_cast<int>(filter_shape.dim_size(3)); + int out_depth = static_cast<int>(filter_shape.dim_size(4)); + + // MKL-DNN always needs filter in OIDHW format. + // OIDHW = (out_depth, in_depth, planes, rows, cols) + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows; + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols; + + *filter_dims = mkldnn_sizes; + } } - // Calculate Convolution filter size in MKL-DNN order. MKL-DNN - // requires filter in OIHW format. Function does not return anything. - // But errors arising from sanity checks are returned in context's - // status. + // Calculate Convolution filter size in MKL-DNN order. + // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format. + // Function does not return anything. But errors arising from sanity + // checks are returned in context's status. virtual inline void GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, memory::dims* filter_dims) { @@ -206,8 +269,8 @@ class MklDnnConvUtil { GetTfShape(context_, filter_index), filter_dims); } - // Calculate Bias size for 2D Convolution. Function does not return - // anything, but sets error in context status. + // Calculate Bias size for 2D or 3D Convolution. Function does not + // return anything, but may set an error in context status. virtual inline void GetBiasSizeInMklOrder(size_t bias_index, memory::dims* bias_dims) { const Tensor& bias = MklGetInput(context_, bias_index); @@ -218,73 +281,142 @@ class MklDnnConvUtil { *bias_dims = {static_cast<int>(bias.dim_size(0))}; } - // Function to calculate output and padding size for 2D convolution. + // Function to calculate output and padding size for 2D/3D convolution. // // Calculate output shape of Convolution in MKL-DNN and TensorFlow order. - // MKL-DNN uses NCHW for output order. But TensorFlow output will be in - // NHWC or NCHW format depending on data format. Function also calculates - // left, right, top and bottom pads. Function does not return any status - - // status is returned via context status. + // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order. + // But TensorFlow output will be in NHWC||NCHW(Conv2D) or + // NDHWC||NCDHW(Conv3D) format depending on data format. + // Function also calculates left, right, top and bottom pads. + // Function does not return any status which is set with context status. // // TODO(nhasabni): Add similar function for input and filter in MklShape. virtual inline void GetOutputAndPadSizeInMklOrder( const TensorShape& input_shape, const TensorShape& filter_shape, const memory::dims& strides, const memory::dims& dilations, - memory::dims* output_dims_tf_order, - memory::dims* output_dims_mkl_order, memory::dims* pad_l, - memory::dims* pad_r) { + memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, + memory::dims* pad_l, memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); CHECK_NOTNULL(pad_r); - int input_rows = GetTensorDim(input_shape, data_format_, 'H'); - int input_cols = GetTensorDim(input_shape, data_format_, 'W'); + bool isConv2D = (strides_.size() == 4); + int input_planes, input_rows, input_cols; + if (isConv2D) { + input_rows = GetTensorDim(input_shape, data_format_, 'H'); + input_cols = GetTensorDim(input_shape, data_format_, 'W'); + } else { + input_planes = GetTensorDim(input_shape, data_format_, '0'); + input_rows = GetTensorDim(input_shape, data_format_, '1'); + input_cols = GetTensorDim(input_shape, data_format_, '2'); + } - // The first dimension for filter is rows/height. - int filter_rows = filter_shape.dim_size(0); - // The second dimension for filter is cols/width. - int filter_cols = filter_shape.dim_size(1); + // Filter dimension + // Conv2D: + // First dimension: rows/height. + // Second dimension: cols/width. + // Conv3D: + // First dimension: planes/depth. + // Second dimension: rows/height. + // Third dimension: cols/width. + + int filter_planes, filter_rows, filter_cols; + if (isConv2D) { + filter_rows = filter_shape.dim_size(0); + filter_cols = filter_shape.dim_size(1); + } else { + filter_planes = filter_shape.dim_size(0); + filter_rows = filter_shape.dim_size(1); + filter_cols = filter_shape.dim_size(2); + } - // Stride is vector of 2 elements: {s_r, s_c} - int stride_rows = strides[0]; - int stride_cols = strides[1]; - int dilation_rows = dilations[0]; - int dilation_cols = dilations[1]; + int stride_planes, stride_rows, stride_cols; + int dilation_planes, dilation_rows, dilation_cols; + if (isConv2D) { + // Conv2D stride is a vector of 2 elements: {s_r, s_c} + stride_rows = strides[0]; + stride_cols = strides[1]; + dilation_rows = dilations[0]; + dilation_cols = dilations[1]; + } else { + // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c} + stride_planes = strides[0]; + stride_rows = strides[1]; + stride_cols = strides[2]; + dilation_planes = dilations[0]; + dilation_rows = dilations[1]; + dilation_cols = dilations[2]; + } // Output batch is same as input batch. int out_batch = GetTensorDim(input_shape, data_format_, 'N'); + // Output depth is same as last dimension for filter. - int out_depth = filter_shape.dim_size(3); + int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4); - int64 out_rows = 0, out_cols = 0; + int64 out_rows = 0, out_cols = 0, out_planes = 0; int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; + int64 pad_D1, pad_D2; + + if (isConv2D) { + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerboseV2( + input_rows, filter_rows, dilation_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerboseV2( + input_cols, filter_cols, dilation_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); + } else { + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_planes, filter_planes, stride_planes, + padding_, &out_planes, &pad_D1, &pad_D2)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); + } - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2(input_rows, filter_rows, - dilation_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2(input_cols, filter_cols, - dilation_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); - - // Tensorflow output is in data_format order. (NHWC or NCHW) + // Tensorflow output is in data_format order. + // Conv2D: NHWC or NCHW + // Conv3D: NDHWC or NCDHW + // MKL-DNN uses asymetric padding. TensorShape out_shape = - ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth); + isConv2D + ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, + out_depth) + : ShapeFromFormat(data_format_, out_batch, + {{out_planes, out_rows, out_cols}}, out_depth); *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); - // MKL-DNN always needs output in NCHW format. - std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; - mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); - mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); - *output_dims_mkl_order = mkldnn_sizes; - - // Now handle padding. MKL-DNN uses asymetric padding. - *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; - *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)}; + if (isConv2D) { + // For Conv2D, MKL-DNN always needs output in NCHW format. + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); + mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); + *output_dims_mkl_order = mkldnn_sizes; + + *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; + *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)}; + } else { + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch; + mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes); + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows); + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols); + *output_dims_mkl_order = mkldnn_sizes; + + *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top), + static_cast<int>(pad_left)}; + *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom), + static_cast<int>(pad_right)}; + } } // Calculate output and pad size of forward Convolution operator. @@ -292,10 +424,10 @@ class MklDnnConvUtil { // // Function does not return anything, but sets error in context status. inline void GetOutputAndPadSizeInMklOrder( - size_t src_index, size_t filter_index, - const memory::dims& strides, const memory::dims& dilations, - memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, - memory::dims* pad_l, memory::dims* pad_r) { + size_t src_index, size_t filter_index, const memory::dims& strides, + const memory::dims& dilations, memory::dims* output_dims_tf_order, + memory::dims* output_dims_mkl_order, memory::dims* pad_l, + memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -304,9 +436,17 @@ class MklDnnConvUtil { auto input_tf_shape = GetTfShape(context_, src_index); auto filter_tf_shape = GetTfShape(context_, filter_index); - OP_REQUIRES(context_, input_tf_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_tf_shape.DebugString())); + if (strides_.size() == 4) { + // Conv2D + OP_REQUIRES(context_, input_tf_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_tf_shape.DebugString())); + } else { + // Conv3D + OP_REQUIRES(context_, input_tf_shape.dims() == 5, + errors::InvalidArgument("input must be 5-dimensional", + input_tf_shape.DebugString())); + } GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, dilations, output_dims_tf_order, @@ -314,9 +454,11 @@ class MklDnnConvUtil { } // Wrapper function to calculate input, filter, and output sizes of - // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.) - // Function also calculates output shape in Tensorflow order. Additionally, it - // also calculates strides and paddings for 2D Convolution. + // Conv2D/Conv3D in MKL order: + // Conv2D: NCHW for input and output; OIHW for filter. + // Conv3D: NCDHW for input and output; OIDHW for filter. + // Function also calculates output shape in Tensorflow order. + // Additionally, it also calculates strides and paddings. // // Function does not return anything, but sets error in context status. inline void GetConvFwdSizesInMklOrder( @@ -349,16 +491,15 @@ class MklDnnConvUtil { } }; - ///////////////////////////////////////////////////////////////////// -/// Common class that implements Conv2DBackpropFilter and Input +/// Common class that implements ConvBackpropFilter and Input ///////////////////////////////////////////////////////////////////// template <typename Device, class T> -class MklConv2DBackpropCommonOp : public OpKernel { +class MklConvBackpropCommonOp : public OpKernel { public: - ~MklConv2DBackpropCommonOp() {} - explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context) + ~MklConvBackpropCommonOp() {} + explicit MklConvBackpropCommonOp(OpKernelConstruction* context) : OpKernel(context) { string data_format_str; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); @@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); - int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + + if (strides_.size() == 4) { + // Check Conv2D dilations + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } |