Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.h  117
1 file changed, 84 insertions(+), 33 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 9dd88221a8..7ca10db895 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -58,13 +58,16 @@ class MklDnnConvUtil {
protected:
OpKernelContext* context_; // We don't own this.
std::vector<int32> strides_;
+ std::vector<int32> dilations_;
Padding padding_;
TensorFormat data_format_;
public:
MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
- Padding pad, TensorFormat fm)
- : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}
+ Padding pad, TensorFormat fm,
+ const std::vector<int32>& dilations) :
+ context_(context), strides_(strides), padding_(pad),
+ data_format_(fm), dilations_(dilations) {}
virtual ~MklDnnConvUtil() { context_ = nullptr; }
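A side note on the new initializer list: C++ initializes members in declaration order (strides_, dilations_, padding_, data_format_), not in the order they appear in the list, so naming dilations_ last is harmless here (it depends on no other member) but draws a -Wreorder warning. A minimal sketch of the rule, with hypothetical names:

#include <iostream>

struct S {
  int a;
  int b;
  // Initialization runs in declaration order (a, then b), even though the
  // initializer list below names b first; -Wreorder flags the mismatch.
  S() : b(2), a(1) {}
};

int main() { std::cout << S().a << S().b << '\n'; }  // prints 12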
@@ -78,6 +81,16 @@ class MklDnnConvUtil {
*strides = {stride_rows, stride_cols};
}
+ // Calculate Convolution dilations in MKL-DNN order.
+ virtual inline void GetDilationsInMklOrder(memory::dims *dilations) {
+ // For now we take the dilation from the second and third dimensions only
+ // (we do not support dilation on the batch or depth dimension).
+ CHECK_NOTNULL(dilations);
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ }
+
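As a sanity check on what GetDilationsInMklOrder yields: GetTensorDim simply indexes the 4-element attribute vector according to data_format_, so NHWC and NCHW attributes describing the same dilation produce the same {rows, cols} pair. A minimal standalone sketch (the index maps are assumptions mirroring TF's GetTensorDim for 4-D formats):

#include <cassert>
#include <vector>

// Hypothetical stand-ins for GetTensorDim's 4-D index lookup.
int IndexNHWC(char d) { return d == 'N' ? 0 : d == 'H' ? 1 : d == 'W' ? 2 : 3; }
int IndexNCHW(char d) { return d == 'N' ? 0 : d == 'C' ? 1 : d == 'H' ? 2 : 3; }

int main() {
  std::vector<int> nhwc = {1, 2, 3, 1};  // dilations attr, NHWC layout
  std::vector<int> nchw = {1, 1, 2, 3};  // same dilations, NCHW layout
  // Both layouts reduce to the MKL-order pair {2, 3} == {rows, cols}.
  assert(nhwc[IndexNHWC('H')] == 2 && nchw[IndexNCHW('H')] == 2);
  assert(nhwc[IndexNHWC('W')] == 3 && nchw[IndexNCHW('W')] == 3);
}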
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
// requires input in NCHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
@@ -213,7 +226,8 @@ class MklDnnConvUtil {
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
- const memory::dims& strides, memory::dims* output_dims_tf_order,
+ const memory::dims& strides, const memory::dims& dilations,
+ memory::dims* output_dims_tf_order,
memory::dims* output_dims_mkl_order, memory::dims* pad_l,
memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
@@ -232,6 +246,8 @@ class MklDnnConvUtil {
// Stride is vector of 2 elements: {s_r, s_c}
int stride_rows = strides[0];
int stride_cols = strides[1];
+ int dilation_rows = dilations[0];
+ int dilation_cols = dilations[1];
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
@@ -241,11 +257,13 @@ class MklDnnConvUtil {
int64 out_rows = 0, out_cols = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
- OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
- input_rows, filter_rows, stride_rows, padding_,
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
+ dilation_rows, stride_rows, padding_,
&out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
- input_cols, filter_cols, stride_cols, padding_,
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
+ dilation_cols, stride_cols, padding_,
&out_cols, &pad_left, &pad_right));
// Tensorflow output is in data_format order. (NHWC or NCHW)
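For reference, the switch from GetWindowedOutputSizeVerbose to GetWindowedOutputSizeVerboseV2 only threads the dilation rate through the window arithmetic. A minimal sketch of the VALID-padding case it computes (the helper name is illustrative; the SAME case additionally derives pad_l/pad_r):

#include <cstdint>

// A dilation rate r spreads the filter taps apart, enlarging the window:
//   effective_filter = (filter - 1) * r + 1
int64_t DilatedOutSizeValid(int64_t in, int64_t filter, int64_t dilation,
                            int64_t stride) {
  const int64_t effective_filter = (filter - 1) * dilation + 1;
  // == ceil((in - effective_filter + 1) / stride) for in >= effective_filter
  return (in - effective_filter + stride) / stride;
}
// Example: in = 10, filter = 3, dilation = 2 -> effective_filter = 5,
// so stride 1 gives 10 - 5 + 1 = 6 output rows (vs. 8 without dilation).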
@@ -271,7 +289,8 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index, const memory::dims& strides,
+ size_t src_index, size_t filter_index,
+ const memory::dims& strides, const memory::dims& dilations,
memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
@@ -286,9 +305,9 @@ class MklDnnConvUtil {
errors::InvalidArgument("input must be 4-dimensional",
input_tf_shape.DebugString()));
- GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
- output_dims_tf_order, output_dims_mkl_order,
- pad_l, pad_r);
+ GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
+ strides, dilations, output_dims_tf_order,
+ output_dims_mkl_order, pad_l, pad_r);
}
// Wrapper function to calculate input, filter, and output sizes of
@@ -300,12 +319,14 @@ class MklDnnConvUtil {
inline void GetConvFwdSizesInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
memory::dims* input_dims, memory::dims* filter_dims,
- memory::dims* strides, memory::dims* output_dims_tf_order,
+ memory::dims* strides, memory::dims *dilations,
+ memory::dims* output_dims_tf_order,
memory::dims* output_dims_mkl_order, memory::dims* pad_l,
memory::dims* pad_r) {
CHECK_NOTNULL(input_dims);
CHECK_NOTNULL(filter_dims);
CHECK_NOTNULL(strides);
+ CHECK_NOTNULL(dilations);
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -316,7 +337,9 @@ class MklDnnConvUtil {
GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
if (!context_->status().ok()) return;
GetStridesInMklOrder(strides);
- GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
+ GetDilationsInMklOrder(dilations);
+ GetOutputAndPadSizeInMklOrder(input_shape, filter_shape,
+ *strides, *dilations,
output_dims_tf_order, output_dims_mkl_order,
pad_l, pad_r);
if (!context_->status().ok()) return;
@@ -344,7 +367,21 @@ class MklConv2DBackpropCommonOp : public OpKernel {
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
-
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
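Concretely, the checks above accept only per-row/column dilation. Illustrative attribute values (hypothetical, NHWC layout):

// dilations = {1, 2, 2, 1}  -> accepted: 2x dilation in H and W only
// dilations = {2, 1, 1, 1}  -> rejected: dilation in the batch dimension
// dilations = {1, 1, 0, 1}  -> rejected: rates must be larger than 0
// dilations = {1, 2, 2}     -> rejected: attr must specify 4 dimensions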
@@ -406,15 +443,16 @@ class MklConv2DBackpropCommonOp : public OpKernel {
// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
- memory::dims padding_l, padding_r, strides, fwd_output_dims;
+ memory::dims padding_l, padding_r, dilations, strides, fwd_output_dims;
memory::dims fwd_output_dims_tf_order;
// Get forward convolution parameters.
- MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
+ dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims,
- &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
- &padding_r);
+ &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
+ &padding_l, &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
@@ -437,10 +475,21 @@ class MklConv2DBackpropCommonOp : public OpKernel {
memory::format::hwio);
// Tensorflow Output of Conv2D is in data_format order.
auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt);
- auto fwd_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, fwd_input_md, fwd_filter_md,
- fwd_out_md, strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
+
+ const int kDilationH = 0, kDilationW = 1;
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+ auto fwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0)?
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, fwd_input_md,
+ fwd_filter_md, fwd_out_md,
+ strides, dilations, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_)) :
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, fwd_input_md,
+ fwd_filter_md, fwd_out_md,
+ strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
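The subtraction above reflects a convention mismatch: TensorFlow's attribute is a sampling rate (1 means dense), while MKL-DNN's dilation counts the zeros inserted between filter taps (0 means dense). Hence the rate is decremented, and a dense convolution falls through to the non-dilated descriptor constructor. A minimal sketch of the mapping:

// TF rate r samples every r-th input element; MKL-DNN dilation d inserts
// d zeros between taps, so d = r - 1 and the default r == 1 maps to d == 0.
inline int TfToMklDnnDilation(int tf_rate) { return tf_rate - 1; }
// r = 1 -> d = 0: no dilation, plain convolution_forward::desc is used.
// r = 2 -> d = 1: one implicit zero between taps, dilated desc is used.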
// Create memory for user data. Describe how the inputs and outputs of
@@ -485,8 +534,9 @@ class MklConv2DBackpropCommonOp : public OpKernel {
// Operator-specific call to create and execute primitive.
CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
- &outbackprop, &output, &output_tensor, strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_),
+ &outbackprop, &output, &output_tensor,
+ strides, dilations, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_),
bwd_output_dims, bwd_output_format);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@@ -535,20 +585,21 @@ class MklConv2DBackpropCommonOp : public OpKernel {
virtual memory::format GetOutputFormat(const memory::format data_format) = 0;
/// Create and execute the primitive storing output in the output_tensor.
- virtual void CreatePrimitive(
- OpKernelContext* context, const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
- MklDnnData<T>* output, Tensor** output_tensor,
- const memory::dims& strides, const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) = 0;
+ virtual void CreatePrimitive(OpKernelContext* context,
+ const engine& cpu_engine,
+ const convolution_forward::primitive_desc& conv_fwd_pd,
+ MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
+ MklDnnData<T>* output, Tensor** output_tensor, const memory::dims& strides,
+ const memory::dims& dilations, const memory::dims& padding_l,
+ const memory::dims& padding_r, padding_kind padding,
+ const memory::dims& bwd_output_dims,
+ memory::format bwd_output_format) = 0;
// Get the data_format {NCHW, NHWC}
TensorFormat GetTFDataFormat() { return data_format_; }
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;