Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
 tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 78 +++++++++----------
 1 file changed, 38 insertions(+), 40 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index f291281108..9080bf7be8 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -45,12 +45,12 @@ limitations under the License.
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
-using mkldnn::stream;
using mkldnn::prop_kind;
+using mkldnn::stream;
-using mkldnn::convolution_forward;
using mkldnn::convolution_backward_weights;
using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
#endif
@@ -463,12 +463,13 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Generate input shapes.
TensorShape filter_shape;
- OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
- errors::InvalidArgument(
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_tensor.shape()),
+ errors::InvalidArgument(
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- filter_tensor.vec<int32>(), &filter_shape));
+ filter_tensor.vec<int32>(), &filter_shape));
TensorShape input_shape = input_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
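
Note on the hunk above: the change only re-wraps an existing check. Conv2DBackpropFilter receives its filter dimensions as a 1-D int32 tensor (filter_sizes) and converts it into a TensorShape before doing any work. A minimal sketch of that validation pattern, with a hypothetical helper name:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

// Hypothetical helper illustrating the filter_sizes -> TensorShape step.
void MakeFilterShape(OpKernelContext* context, const Tensor& filter_tensor,
                     TensorShape* filter_shape) {
  // filter_sizes must be rank-1, e.g. {H, W, in_channels, out_channels}.
  OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
              errors::InvalidArgument(
                  "filter_sizes input must be 1-dim, not ",
                  filter_tensor.dims()));
  // Convert the int32 vector into a TensorShape, failing on bad values.
  OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                              filter_tensor.vec<int32>(), filter_shape));
}

}  // namespace tensorflow
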
@@ -480,27 +481,26 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
- &fwd_input_dims, &fwd_filter_dims,
- &strides,
- &fwd_output_dims_tf_order,
- &fwd_output_dims,
- &padding_l, &padding_r);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
- auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
- auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_desc = convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
- strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
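
Background for this hunk (the change itself is formatting-only): MKL-DNN 0.x requires a forward convolution primitive descriptor as a hint when building the backward-weights primitive, which is why the kernel constructs the three memory descriptors and fwd_desc first. A standalone sketch of that construction against raw MKL-DNN 0.x, with made-up sizes:

#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::cpu, 0);

  // Hypothetical shapes: batch 1, 3 -> 8 channels, 5x5 input, 3x3 kernel.
  memory::dims src_dims = {1, 3, 5, 5};     // NCHW
  memory::dims filter_dims = {8, 3, 3, 3};  // OIHW
  memory::dims dst_dims = {1, 8, 3, 3};
  memory::dims strides = {1, 1};
  memory::dims padding = {0, 0};

  // Let the library pick layouts here; the kernel instead uses the TF
  // data format (nchw/nhwc) for src/dst and hwio for the filter.
  auto src_md = memory::desc(src_dims, memory::data_type::f32,
                             memory::format::any);
  auto filter_md = memory::desc(filter_dims, memory::data_type::f32,
                                memory::format::any);
  auto dst_md = memory::desc(dst_dims, memory::data_type::f32,
                             memory::format::any);

  auto fwd_desc = convolution_forward::desc(
      prop_kind::forward, convolution_direct, src_md, filter_md, dst_md,
      strides, padding, padding, padding_kind::zero);
  auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
  (void)fwd_pd;  // used later as the hint for backward weights
  return 0;
}
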
// Allocate output tensor and shape
@@ -537,23 +537,22 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward weights primitive.
- auto bwd_desc = convolution_backward_weights::desc(convolution_direct,
- input.GetOpMemDesc(), output.GetOpMemDesc(),
- outbackprop.GetOpMemDesc(), strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto bwd_desc = convolution_backward_weights::desc(
+ convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- fwd_pd);
+ auto bwd_pd = convolution_backward_weights::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
- } catch (mkldnn::error &e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) +
- ", in file " + string(__FILE__) + ":" +
- std::to_string(__LINE__);
- OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
- error_msg));
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
}
}
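
Similarly, the backward-weights hunk above only re-flows the call: the desc takes the source, diff-weights, and diff-dst memory descriptors plus strides and padding, and the primitive_desc takes the forward primitive_desc as its required hint. Continuing the sketch above (same assumed dims, descriptors, and fwd_pd):

// Continuation of the sketch: diff_weights uses the filter's dims and
// diff_dst uses the forward output's dims.
auto bwd_desc = convolution_backward_weights::desc(
    convolution_direct, src_md, filter_md, dst_md, strides, padding,
    padding, padding_kind::zero);
auto bwd_pd = convolution_backward_weights::primitive_desc(
    bwd_desc, cpu_engine, fwd_pd);  // fwd_pd is the hint
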
@@ -564,9 +563,8 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
- const convolution_backward_weights::primitive_desc& conv_pd,
- MklDnnData<T>* input, MklDnnData<T>* obp,
- MklDnnData<T>* output) {
+ const convolution_backward_weights::primitive_desc& conv_pd,
+ MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
@@ -577,10 +575,10 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_pd.diff_weights_primitive_desc());
+ conv_pd.diff_weights_primitive_desc());
- net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
- obp->GetOpMem(), output->GetOpMem()));
+ net.push_back(convolution_backward_weights(
+ conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.
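
For reference, the reorder-and-execute flow that PrepareAndExecutePrimitive encapsulates looks roughly like this in raw MKL-DNN 0.x. This is a sketch with assumed user-layout memory objects, not the kernel's actual helper, which routes everything through MklDnnData:

#include <vector>
#include "mkldnn.hpp"

// Sketch (assumed names): reorder inputs into the layouts the primitive
// prefers, run backward weights, then reorder the result back if needed.
void RunBackwardWeights(
    const mkldnn::convolution_backward_weights::primitive_desc& conv_pd,
    const mkldnn::memory& user_src, const mkldnn::memory& user_diff_dst,
    const mkldnn::memory& user_diff_weights) {
  using namespace mkldnn;
  std::vector<primitive> net;

  // Input-side reorders, only when the user layout differs.
  memory conv_src = user_src;
  if (conv_pd.src_primitive_desc() != user_src.get_primitive_desc()) {
    conv_src = memory(conv_pd.src_primitive_desc());
    net.push_back(reorder(user_src, conv_src));
  }
  memory conv_diff_dst = user_diff_dst;
  if (conv_pd.diff_dst_primitive_desc() !=
      user_diff_dst.get_primitive_desc()) {
    conv_diff_dst = memory(conv_pd.diff_dst_primitive_desc());
    net.push_back(reorder(user_diff_dst, conv_diff_dst));
  }

  // Output side: compute into a primitive-preferred buffer when the user
  // layout differs, mirroring PrepareReorderToUserMemIfReq.
  bool output_reorder_required =
      conv_pd.diff_weights_primitive_desc() !=
      user_diff_weights.get_primitive_desc();
  memory conv_diff_weights =
      output_reorder_required
          ? memory(conv_pd.diff_weights_primitive_desc())
          : user_diff_weights;

  net.push_back(convolution_backward_weights(
      conv_pd, conv_src, conv_diff_dst, conv_diff_weights));

  // Insert the output reorder after the convolution, if required.
  if (output_reorder_required) {
    net.push_back(reorder(conv_diff_weights, user_diff_weights));
  }

  // Execute the whole net eagerly on the CPU stream.
  stream(stream::kind::eager).submit(net).wait();
}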