author    TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 13:04:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 13:05:05 -0700
commit    9c50882415cb87a7eb81048d42401c64bf0617ef (patch)
tree      c550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core/kernels/mkl_conv_ops.h
parent    19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff)
parent    62191da0819b25906c1b2ed96159cfe36ba00383 (diff)
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.h | 414
1 file changed, 280 insertions(+), 134 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 838c06f49d..01cc606f41 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -79,9 +79,16 @@ class MklDnnConvUtil {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
- int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- *strides = {stride_rows, stride_cols};
+ if (strides_.size() == 4) {
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ } else if (strides_.size() == 5) {
+ int stride_planes = GetTensorDim(strides_, data_format_, '0');
+ int stride_rows = GetTensorDim(strides_, data_format_, '1');
+ int stride_cols = GetTensorDim(strides_, data_format_, '2');
+ *strides = {stride_planes, stride_rows, stride_cols};
+ }
}
// Calculate Convolution dilations
@@ -89,13 +96,20 @@ class MklDnnConvUtil {
// For now we take the dilation from the second and third dimensions only
// (we do not support dilation on the batch or depth dimension).
CHECK_NOTNULL(dilations);
- int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
- int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
- *dilations = {dilations_rows, dilations_cols};
+ if (dilations_.size() == 4) {
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ } else if (dilations_.size() == 5) {
+ int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
+ int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
+ *dilations = {dilations_planes, dilations_rows, dilations_cols};
+ }
}
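Both the stride and dilation hunks above dispatch on the attribute length: a 4-element vector is treated as Conv2D and indexed by 'H'/'W', a 5-element vector as Conv3D and indexed by the numbered spatial dimensions '0', '1', '2'. A minimal standalone sketch of that dispatch, assuming the GetTensorDim overload for attribute vectors from tensorflow/core/util/tensor_format.h; the SpatialStrides helper itself is hypothetical and not part of the kernel:

    #include <vector>
    #include "tensorflow/core/util/tensor_format.h"

    // Hypothetical helper: returns only the spatial strides, in {rows, cols}
    // order for Conv2D or {planes, rows, cols} order for Conv3D.
    // Assumes the tensorflow namespace for GetTensorDim/TensorFormat/int32.
    std::vector<int32> SpatialStrides(const std::vector<int32>& strides,
                                      TensorFormat data_format) {
      if (strides.size() == 4) {  // Conv2D
        int32 r = GetTensorDim(strides, data_format, 'H');
        int32 c = GetTensorDim(strides, data_format, 'W');
        return {r, c};
      }
      // Conv3D: '0', '1', '2' address the three spatial dimensions.
      int32 p = GetTensorDim(strides, data_format, '0');
      int32 r = GetTensorDim(strides, data_format, '1');
      int32 c = GetTensorDim(strides, data_format, '2');
      return {p, r, c};
    }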
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
- // requires input in NCHW format. Function does not return anything.
+ // requires input in NCHW/NCDHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
@@ -113,40 +127,62 @@ class MklDnnConvUtil {
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
- // Input rows/height
- int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
- CHECK_BOUNDS(input_rows_raw, "Input rows too large");
- int input_rows = static_cast<int>(input_rows_raw);
-
- // Input columns/width
- int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
- CHECK_BOUNDS(input_cols_raw, "Input cols too large");
- int input_cols = static_cast<int>(input_cols_raw);
-
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
+ if (strides_.size() == 4) { // NCHW format for Conv2D
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCHW format for Conv2D.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ } else if (strides_.size() == 5) { // NCDHW format for Conv3D
+ // Input planes/third-dimension
+ int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
+ CHECK_BOUNDS(input_planes_raw, "Input depth too large");
+ int input_planes = static_cast<int>(input_planes_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCDHW format for Conv3D.
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ }
#undef CHECK_BOUNDS
-
- // MKL-DNN always requires input in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
-
- *input_dims = mkldnn_sizes;
}
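As a concrete illustration of the Conv3D branch above, the numbered-dimension lookups and the NCDHW slot assignment behave roughly as follows for an NDHWC input. The shape values are hypothetical, and MklDnnDims3D is assumed to enumerate NCDHW positions as in mkl_util.h:

    // Hypothetical NDHWC input: N=8, D=16, H=64, W=64, C=3.
    TensorShape input_shape({8, 16, 64, 64, 3});
    int64 batch  = GetTensorDim(input_shape, FORMAT_NHWC, 'N');  // 8
    int64 depth  = GetTensorDim(input_shape, FORMAT_NHWC, 'C');  // 3 (channels)
    int64 planes = GetTensorDim(input_shape, FORMAT_NHWC, '0');  // 16
    int64 rows   = GetTensorDim(input_shape, FORMAT_NHWC, '1');  // 64
    int64 cols   = GetTensorDim(input_shape, FORMAT_NHWC, '2');  // 64
    // Filled into the NCDHW slots used above: {8, 3, 16, 64, 64}.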
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
- //
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
@@ -159,11 +195,13 @@ class MklDnnConvUtil {
memory::dims* filter_dims) {
CHECK_NOTNULL(filter_dims);
- OP_REQUIRES(context_, filter_shape.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
+ OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
+ errors::InvalidArgument((strides_.size() == 4)
+ ? "filter must be 4-dimensional: "
+ : "filter must be 5-dimensional: ",
filter_shape.DebugString()));
- for (int i = 0; i < 3; i++) {
+ for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
OP_REQUIRES(context_,
FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
@@ -172,32 +210,57 @@ class MklDnnConvUtil {
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
-
- // TF filter is always in (rows, cols, in_depth, out_depth) order.
- int filter_rows = static_cast<int>(filter_shape.dim_size(0));
- int filter_cols = static_cast<int>(filter_shape.dim_size(1));
- int in_depth = static_cast<int>(filter_shape.dim_size(2));
- int out_depth = static_cast<int>(filter_shape.dim_size(3));
-
- // MKL-DNN always needs filter in OIHW format.
- // OIHW = (out_depth, in_depth, rows, cols)
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
-
- *filter_dims = mkldnn_sizes;
+ if (strides_.size() == 4) { // Conv2D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ } else { // Conv3D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(3)));
+
+ // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
+ int filter_planes = static_cast<int>(filter_shape.dim_size(0));
+ int filter_rows = static_cast<int>(filter_shape.dim_size(1));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(2));
+ int in_depth = static_cast<int>(filter_shape.dim_size(3));
+ int out_depth = static_cast<int>(filter_shape.dim_size(4));
+
+ // MKL-DNN always needs filter in OIDHW format.
+ // OIDHW = (out_depth, in_depth, planes, rows, cols)
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ }
}
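For reference, the Conv3D filter permutation above takes TensorFlow's (planes, rows, cols, in_depth, out_depth) layout to MKL-DNN's OIDHW. A small worked example, under the assumption that Dim3d_O through Dim3d_W enumerate OIDHW positions 0..4 (as the aliases in mkl_util.h suggest); the filter shape is hypothetical:

    // Hypothetical TF Conv3D filter: planes=3, rows=5, cols=5, in=16, out=32.
    TensorShape filter_shape({3, 5, 5, 16, 32});
    std::vector<int> mkldnn_sizes(5, -1);
    mkldnn_sizes[MklDnnDims3D::Dim3d_O] = static_cast<int>(filter_shape.dim_size(4));  // 32
    mkldnn_sizes[MklDnnDims3D::Dim3d_I] = static_cast<int>(filter_shape.dim_size(3));  // 16
    mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(filter_shape.dim_size(0));  // 3
    mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(filter_shape.dim_size(1));  // 5
    mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(filter_shape.dim_size(2));  // 5
    // OIDHW result: {32, 16, 3, 5, 5}.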
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything. But errors arising from sanity
+ // checks are returned in context's status.
virtual inline void GetFilterSizeInMklOrder(size_t src_index,
size_t filter_index,
memory::dims* filter_dims) {
@@ -206,8 +269,8 @@ class MklDnnConvUtil {
GetTfShape(context_, filter_index), filter_dims);
}
- // Calculate Bias size for 2D Convolution. Function does not return
- // anything, but sets error in context status.
+ // Calculate Bias size for 2D or 3D Convolution. Function does not
+ // return anything, but may set an error in context status.
virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
memory::dims* bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
@@ -218,73 +281,142 @@ class MklDnnConvUtil {
*bias_dims = {static_cast<int>(bias.dim_size(0))};
}
- // Function to calculate output and padding size for 2D convolution.
+ // Function to calculate output and padding size for 2D/3D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function also calculates
- // left, right, top and bottom pads. Function does not return any status -
- // status is returned via context status.
+ // MKL-DNN uses NCHW (Conv2D) or NCDHW (Conv3D) for output order.
+ // But TensorFlow output will be in NHWC or NCHW (Conv2D), or
+ // NDHWC or NCDHW (Conv3D), depending on data format.
+ // Function also calculates left, right, top and bottom pads.
+ // Function does not return any status; errors are reported via context's status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order,
- memory::dims* output_dims_mkl_order, memory::dims* pad_l,
- memory::dims* pad_r) {
+ memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
+ memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- int input_rows = GetTensorDim(input_shape, data_format_, 'H');
- int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ bool isConv2D = (strides_.size() == 4);
+ int input_planes, input_rows, input_cols;
+ if (isConv2D) {
+ input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ } else {
+ input_planes = GetTensorDim(input_shape, data_format_, '0');
+ input_rows = GetTensorDim(input_shape, data_format_, '1');
+ input_cols = GetTensorDim(input_shape, data_format_, '2');
+ }
- // The first dimension for filter is rows/height.
- int filter_rows = filter_shape.dim_size(0);
- // The second dimension for filter is cols/width.
- int filter_cols = filter_shape.dim_size(1);
+ // Filter dimension
+ // Conv2D:
+ // First dimension: rows/height.
+ // Second dimension: cols/width.
+ // Conv3D:
+ // First dimension: planes/depth.
+ // Second dimension: rows/height.
+ // Third dimension: cols/width.
+
+ int filter_planes, filter_rows, filter_cols;
+ if (isConv2D) {
+ filter_rows = filter_shape.dim_size(0);
+ filter_cols = filter_shape.dim_size(1);
+ } else {
+ filter_planes = filter_shape.dim_size(0);
+ filter_rows = filter_shape.dim_size(1);
+ filter_cols = filter_shape.dim_size(2);
+ }
- // Stride is vector of 2 elements: {s_r, s_c}
- int stride_rows = strides[0];
- int stride_cols = strides[1];
- int dilation_rows = dilations[0];
- int dilation_cols = dilations[1];
+ int stride_planes, stride_rows, stride_cols;
+ int dilation_planes, dilation_rows, dilation_cols;
+ if (isConv2D) {
+ // Conv2D stride is a vector of 2 elements: {s_r, s_c}
+ stride_rows = strides[0];
+ stride_cols = strides[1];
+ dilation_rows = dilations[0];
+ dilation_cols = dilations[1];
+ } else {
+ // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
+ stride_planes = strides[0];
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ dilation_planes = dilations[0];
+ dilation_rows = dilations[1];
+ dilation_cols = dilations[2];
+ }
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+
// Output depth is same as last dimension for filter.
- int out_depth = filter_shape.dim_size(3);
+ int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4);
- int64 out_rows = 0, out_cols = 0;
+ int64 out_rows = 0, out_cols = 0, out_planes = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+ int64 pad_D1, pad_D2;
+
+ if (isConv2D) {
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_rows, filter_rows, dilation_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_cols, filter_cols, dilation_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ } else {
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_planes, filter_planes, stride_planes,
+ padding_, &out_planes, &pad_D1, &pad_D2));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ }
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
- dilation_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
- dilation_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
-
- // Tensorflow output is in data_format order. (NHWC or NCHW)
+ // Tensorflow output is in data_format order.
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
+ // MKL-DNN uses asymmetric padding.
TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ isConv2D
+ ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
+ out_depth)
+ : ShapeFromFormat(data_format_, out_batch,
+ {{out_planes, out_rows, out_cols}}, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
- // MKL-DNN always needs output in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
- mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
- *output_dims_mkl_order = mkldnn_sizes;
-
- // Now handle padding. MKL-DNN uses asymetric padding.
- *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
- *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ if (isConv2D) {
+ // For Conv2D, MKL-DNN always needs output in NCHW format.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ } else {
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
+ static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
+ static_cast<int>(pad_right)};
+ }
}
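The per-dimension output and padding values above come from GetWindowedOutputSizeVerboseV2 (and its dilation-free variant GetWindowedOutputSizeVerbose on the 3D path). A condensed sketch of that arithmetic, not the library routine itself, for readers who want the closed form:

    #include <algorithm>

    // Sketch of the SAME/VALID output-size arithmetic assumed above; the
    // resulting pad_before/pad_after pairs are what feed *pad_l and *pad_r.
    // Assumes the tensorflow Padding enum (VALID/SAME) and int64 typedef.
    void WindowedOutputSize(int64 in, int64 filter, int64 dilation, int64 stride,
                            Padding padding, int64* out, int64* pad_before,
                            int64* pad_after) {
      const int64 effective_filter = (filter - 1) * dilation + 1;
      if (padding == VALID) {
        *out = (in - effective_filter + stride) / stride;  // ceil((in - ef + 1) / stride)
        *pad_before = *pad_after = 0;
      } else {  // SAME
        *out = (in + stride - 1) / stride;                 // ceil(in / stride)
        const int64 needed =
            std::max<int64>(0, (*out - 1) * stride + effective_filter - in);
        *pad_before = needed / 2;           // top / left / D1
        *pad_after = needed - *pad_before;  // bottom / right / D2 (may differ by 1)
      }
    }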
// Calculate output and pad size of forward Convolution operator.
@@ -292,10 +424,10 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index,
- const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
- memory::dims* pad_l, memory::dims* pad_r) {
+ size_t src_index, size_t filter_index, const memory::dims& strides,
+ const memory::dims& dilations, memory::dims* output_dims_tf_order,
+ memory::dims* output_dims_mkl_order, memory::dims* pad_l,
+ memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -304,9 +436,17 @@ class MklDnnConvUtil {
auto input_tf_shape = GetTfShape(context_, src_index);
auto filter_tf_shape = GetTfShape(context_, filter_index);
- OP_REQUIRES(context_, input_tf_shape.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ if (strides_.size() == 4) {
+ // Conv2D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_tf_shape.DebugString()));
+ } else {
+ // Conv3D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input_tf_shape.DebugString()));
+ }
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, dilations, output_dims_tf_order,
@@ -314,9 +454,11 @@ class MklDnnConvUtil {
}
// Wrapper function to calculate input, filter, and output sizes of
- // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
- // Function also calculates output shape in Tensorflow order. Additionally, it
- // also calculates strides and paddings for 2D Convolution.
+ // Conv2D/Conv3D in MKL order:
+ // Conv2D: NCHW for input and output; OIHW for filter.
+ // Conv3D: NCDHW for input and output; OIDHW for filter.
+ // Function also calculates output shape in Tensorflow order.
+ // Additionally, it calculates strides and paddings.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(
@@ -349,16 +491,15 @@ class MklDnnConvUtil {
}
};
-
/////////////////////////////////////////////////////////////////////
-/// Common class that implements Conv2DBackpropFilter and Input
+/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
template <typename Device, class T>
-class MklConv2DBackpropCommonOp : public OpKernel {
+class MklConvBackpropCommonOp : public OpKernel {
public:
- ~MklConv2DBackpropCommonOp() {}
- explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
+ ~MklConvBackpropCommonOp() {}
+ explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
@@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ // Check Conv2D dilations
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
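The dilation checks above are guarded to the 4-element (Conv2D) case. For illustration only, a hypothetical analogous check for a 5-element (Conv3D) dilations attribute, following the same GetTensorDim pattern as the Conv2D branch; this is not part of the diff:

    // Hypothetical sketch, not in this commit: Conv3D-style dilation validation.
    if (strides_.size() == 5) {
      OP_REQUIRES(context, dilations_.size() == 5,
                  errors::InvalidArgument("Sliding window dilations field must "
                                          "specify 5 dimensions"));
      int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
      int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
      int dilation_d = GetTensorDim(dilations_, data_format_, '0');
      int dilation_h = GetTensorDim(dilations_, data_format_, '1');
      int dilation_w = GetTensorDim(dilations_, data_format_, '2');
      OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations in the batch and depth dimensions."));
      OP_REQUIRES(context, dilation_d > 0 && dilation_h > 0 && dilation_w > 0,
                  errors::InvalidArgument("Dilated rates should be larger than 0."));
    }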