aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc81
1 files changed, 53 insertions, 28 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 1401bc65a4..e0706568b1 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -444,6 +444,7 @@ class MklConv2DCustomBackpropFilterOp
~MklConv2DCustomBackpropFilterOp() {}
private:
+ const int kDilationH = 0, kDilationW = 1;
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -492,7 +493,9 @@ class MklConv2DCustomBackpropFilterOp
const convolution_forward::primitive_desc& conv_fwd_pd,
MklDnnData<T>* input, MklDnnData<T>* filter,
MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor, const memory::dims& strides,
+ Tensor** output_tensor,
+ const memory::dims& strides,
+ const memory::dims& dilations,
const memory::dims& padding_l,
const memory::dims& padding_r, padding_kind padding,
const memory::dims& bwd_output_dims,
@@ -518,31 +521,32 @@ class MklConv2DCustomBackpropFilterOp
bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x);
}
- // Create convolution backward weights primitive.
- auto bwd_desc =
- (biasEnabled && (bias_grad != nullptr))
- ? convolution_backward_weights::desc(
- convolution_direct, input->GetOpMemDesc(),
- output->GetOpMemDesc(), bias_grad->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides, padding_l, padding_r,
- padding)
- : convolution_backward_weights::desc(
- convolution_direct, input->GetOpMemDesc(),
- output->GetOpMemDesc(), outbackprop->GetOpMemDesc(), strides,
- padding_l, padding_r, padding);
-
- auto bwd_pd = convolution_backward_weights::primitive_desc(
- bwd_desc, cpu_engine, conv_fwd_pd);
-
- // Allocate output tensor.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
- output_tensor);
-
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
-
if (biasEnabled && (bias_grad != nullptr)) {
+ // Create convolution backward weights with bias primitive.
+ // Use dilated convolution in case dilate rates are greater than zero.
+ auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ bias_grad->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(), strides,
+ dilations, padding_l, padding_r, padding) :
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ bias_grad->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(),
+ strides, padding_l, padding_r, padding);
+ auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
+ cpu_engine,
+ conv_fwd_pd);
+
+ // Allocate output tensor.
+ AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
+ bwd_output_format, output_tensor);
+
+ CHECK_NOTNULL(*output_tensor);
+ // Set buffer handle using allocated output tensor.
+ output->SetUsrMemDataHandle(*output_tensor);
+
// Allocate bias_grad tensor
TensorShape bias_grad_shape({depth});
Tensor* bias_grad_tensor = nullptr;
@@ -553,11 +557,32 @@ class MklConv2DCustomBackpropFilterOp
memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x);
bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor);
bias_grad->SetUsrMemDataHandle(bias_grad_tensor);
- }
- if (biasEnabled && (bias_grad != nullptr)) {
- PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output, bias_grad);
+ PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output,
+ bias_grad);
} else {
+ // Create convolution backward weights primitive.
+ // Use dilated convolution in case dilate rates are greater than zero.
+ auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(), strides,
+ dilations, padding_l, padding_r, padding) :
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(),
+ strides, padding_l, padding_r, padding);
+ auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
+ cpu_engine,
+ conv_fwd_pd);
+
+ // Allocate output tensor.
+ AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
+ bwd_output_format, output_tensor);
+
+ CHECK_NOTNULL(*output_tensor);
+ // Set buffer handle using allocated output tensor.
+ output->SetUsrMemDataHandle(*output_tensor);
PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output);
}
}