Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc  86
1 file changed, 42 insertions(+), 44 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 4a47d0463e..4b6bf92e42 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,6 +23,8 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -41,18 +43,16 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
-using mkldnn::stream;
using mkldnn::prop_kind;
+using mkldnn::stream;
-using mkldnn::convolution_forward;
-using mkldnn::convolution_direct;
using mkldnn::convolution_backward_data;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
#endif
namespace tensorflow {
@@ -397,12 +397,13 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Generate input shape.
TensorShape input_shape;
- OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
- errors::InvalidArgument(
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_tensor.shape()),
+ errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- input_tensor.vec<int32>(), &input_shape));
+ input_tensor.vec<int32>(), &input_shape));
TensorShape filter_shape = filter_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
@@ -414,27 +415,26 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
- &fwd_input_dims, &fwd_filter_dims,
- &strides,
- &fwd_output_dims_tf_order,
- &fwd_output_dims,
- &padding_l, &padding_r);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
- auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
- auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_desc = convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
- strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
// Allocate output tensor and shape
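
Context for readers unfamiliar with the MKL-DNN v0.x C++ API being reformatted above: a backward-data primitive descriptor must be created with a forward primitive descriptor as a hint, which is why the kernel builds fwd_desc/fwd_pd even though it only runs the backward pass. Below is a minimal standalone sketch of that hint pattern; the shapes, strides, zero padding, and plain nchw/oihw formats are illustrative assumptions, not values taken from this kernel.

#include "mkldnn.hpp"
using namespace mkldnn;

int main() {
  engine cpu_engine(engine::cpu, 0);

  // Illustrative shapes: batch 1, 3 -> 8 channels, 32x32 input, 3x3 filter.
  memory::dims src_dims = {1, 3, 32, 32};   // NCHW
  memory::dims filter_dims = {8, 3, 3, 3};  // OIHW
  memory::dims dst_dims = {1, 8, 30, 30};   // 32 - 3 + 1 = 30 with no padding
  memory::dims strides = {1, 1};
  memory::dims padding = {0, 0};

  auto src_md = memory::desc(src_dims, memory::data_type::f32,
                             memory::format::nchw);
  auto filter_md = memory::desc(filter_dims, memory::data_type::f32,
                                memory::format::oihw);
  auto dst_md = memory::desc(dst_dims, memory::data_type::f32,
                             memory::format::nchw);

  // Forward descriptor: needed only so the backward-data primitive
  // descriptor below can take the forward primitive descriptor as a hint.
  auto fwd_desc = convolution_forward::desc(
      prop_kind::forward, convolution_direct, src_md, filter_md, dst_md,
      strides, padding, padding, padding_kind::zero);
  auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);

  // Backward-data descriptor: computes the gradient w.r.t. the input
  // (diff_src) from the filter and the output gradient (diff_dst).
  auto bwd_desc = convolution_backward_data::desc(
      convolution_direct, src_md, filter_md, dst_md, strides, padding,
      padding, padding_kind::zero);
  auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc,
                                                          cpu_engine, fwd_pd);
  (void)bwd_pd;
  return 0;
}
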
@@ -475,23 +475,22 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward data primitive.
- auto bwd_desc = convolution_backward_data::desc(convolution_direct,
- output.GetOpMemDesc(), filter.GetOpMemDesc(),
- outbackprop.GetOpMemDesc(), strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_));
+ auto bwd_desc = convolution_backward_data::desc(
+ convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
- auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc,
- cpu_engine,
- fwd_pd);
+ auto bwd_pd = convolution_backward_data::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
- } catch (mkldnn::error &e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) +
- ", in file " + string(__FILE__) + ":" +
- std::to_string(__LINE__);
- OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
- error_msg));
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
}
}
@@ -502,9 +501,8 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
- const convolution_backward_data::primitive_desc& conv_pd,
- MklDnnData<T>* filter, MklDnnData<T>* obp,
- MklDnnData<T>* output) {
+ const convolution_backward_data::primitive_desc& conv_pd,
+ MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
@@ -514,11 +512,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Memory for output of convolution. Since we may need reorder on the
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
- bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_pd.diff_src_primitive_desc());
+ bool output_reorder_required =
+ output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc());
- net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(),
- filter->GetOpMem(), output->GetOpMem()));
+ net.push_back(convolution_backward_data(
+ conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.
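
The net built by PrepareAndExecutePrimitive follows the standard v0.x execution pattern: queue the convolution primitive, queue an output reorder after it if the user layout differs from the primitive's preferred layout, then submit the whole vector to an eager stream. The following is a hedged sketch of that pattern only; RunBackwardData and all memory-object names are hypothetical, and the input memories are assumed to already be in the layouts the primitive expects.

#include <vector>
#include "mkldnn.hpp"
using namespace mkldnn;

// Sketch only: diff_dst_mem and filter_mem are assumed to already match the
// primitive's expected layouts; user_diff_src_mem is the caller's
// (TensorFlow-layout) destination for the computed input gradient.
void RunBackwardData(const convolution_backward_data::primitive_desc& bwd_pd,
                     memory& diff_dst_mem, memory& filter_mem,
                     memory& user_diff_src_mem) {
  std::vector<primitive> net;

  // If the primitive's preferred diff_src layout differs from the user
  // layout, compute into a staging memory and reorder back afterwards.
  bool output_reorder_required =
      !(bwd_pd.diff_src_primitive_desc() ==
        user_diff_src_mem.get_primitive_desc());
  memory op_diff_src_mem = output_reorder_required
                               ? memory(bwd_pd.diff_src_primitive_desc())
                               : user_diff_src_mem;

  // Argument order matches the kernel above: diff_dst, weights, diff_src.
  net.push_back(convolution_backward_data(bwd_pd, diff_dst_mem, filter_mem,
                                          op_diff_src_mem));

  // The output reorder is appended after the convolution, mirroring the
  // "insert reorder primitive in the net" step in the kernel.
  if (output_reorder_required)
    net.push_back(reorder(op_diff_src_mem, user_diff_src_mem));

  stream(stream::kind::eager).submit(net).wait();
}
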