aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc78
1 files changed, 40 insertions, 38 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 9080bf7be8..f291281108 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -45,12 +45,12 @@ limitations under the License.
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
-using mkldnn::prop_kind;
using mkldnn::stream;
+using mkldnn::prop_kind;
+using mkldnn::convolution_forward;
using mkldnn::convolution_backward_weights;
using mkldnn::convolution_direct;
-using mkldnn::convolution_forward;
#endif
@@ -463,13 +463,12 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Generate input shapes.
TensorShape filter_shape;
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(filter_tensor.shape()),
- errors::InvalidArgument(
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
+ errors::InvalidArgument(
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- filter_tensor.vec<int32>(), &filter_shape));
+ filter_tensor.vec<int32>(), &filter_shape));
TensorShape input_shape = input_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
@@ -481,26 +480,27 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(
- input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
- &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
- &padding_r);
+ conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
+ &fwd_input_dims, &fwd_filter_dims,
+ &strides,
+ &fwd_output_dims_tf_order,
+ &fwd_output_dims,
+ &padding_l, &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
- auto fwd_src_md =
- memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
- auto fwd_filter_md =
- memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
- auto fwd_out_md =
- memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
- auto fwd_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
- fwd_out_md, strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
+ auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
+ mkl_data_format);
+ auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ memory::format::hwio);
+ auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
+ mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(prop_kind::forward,
+ convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
+ strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
// Allocate output tensor and shape
@@ -537,22 +537,23 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward weights primitive.
- auto bwd_desc = convolution_backward_weights::desc(
- convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(),
- outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
+ auto bwd_desc = convolution_backward_weights::desc(convolution_direct,
+ input.GetOpMemDesc(), output.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l,
+ padding_r, TFPaddingToMklDnnPadding(padding_));
- auto bwd_pd = convolution_backward_weights::primitive_desc(
- bwd_desc, cpu_engine, fwd_pd);
+ auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
+ cpu_engine,
+ fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
- } catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
@@ -563,8 +564,9 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
- const convolution_backward_weights::primitive_desc& conv_pd,
- MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ const convolution_backward_weights::primitive_desc& conv_pd,
+ MklDnnData<T>* input, MklDnnData<T>* obp,
+ MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
@@ -575,10 +577,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_pd.diff_weights_primitive_desc());
+ conv_pd.diff_weights_primitive_desc());
- net.push_back(convolution_backward_weights(
- conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
+ net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
+ obp->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.