aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index eeed009531..d203c04934 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -369,6 +369,7 @@ class MklConv2DCustomBackpropInputOp
private:
const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0,
kInputIndex_OutBackProp = 2;
+ const int kDilationH = 0, kDilationW = 1;
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -419,7 +420,9 @@ class MklConv2DCustomBackpropInputOp
const convolution_forward::primitive_desc& conv_fwd_pd,
MklDnnData<T>* input, MklDnnData<T>* filter,
MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor, const memory::dims& strides,
+ Tensor** output_tensor,
+ const memory::dims& strides,
+ const memory::dims& dilations,
const memory::dims& padding_l,
const memory::dims& padding_r, padding_kind padding,
const memory::dims& bwd_output_dims,
@@ -432,9 +435,16 @@ class MklConv2DCustomBackpropInputOp
CHECK_NOTNULL(output_tensor);
// Create convolution backward data primitive.
- auto bwd_desc = convolution_backward_data::desc(
- convolution_direct, output->GetOpMemDesc(), filter->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, padding);
+ // Use dilated convolution in case dilate rates are greater than zero.
+ auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
+ convolution_backward_data::desc(convolution_direct,
+ output->GetOpMemDesc(), filter->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(), strides,
+ dilations, padding_l, padding_r, padding):
+ convolution_backward_data::desc(convolution_direct,
+ output->GetOpMemDesc(), filter->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(),
+ strides, padding_l, padding_r, padding);
auto bwd_pd = convolution_backward_data::primitive_desc(
bwd_desc, cpu_engine, conv_fwd_pd);