aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc86
1 files changed, 44 insertions, 42 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 4b6bf92e42..4a47d0463e 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,8 +23,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -43,16 +41,18 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
-using mkldnn::prop_kind;
using mkldnn::stream;
+using mkldnn::prop_kind;
-using mkldnn::convolution_backward_data;
-using mkldnn::convolution_direct;
using mkldnn::convolution_forward;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_backward_data;
#endif
namespace tensorflow {
@@ -397,13 +397,12 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Generate input shape.
TensorShape input_shape;
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(input_tensor.shape()),
- errors::InvalidArgument(
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
+ errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- input_tensor.vec<int32>(), &input_shape));
+ input_tensor.vec<int32>(), &input_shape));
TensorShape filter_shape = filter_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
@@ -415,26 +414,27 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(
- input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
- &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
- &padding_r);
+ conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
+ &fwd_input_dims, &fwd_filter_dims,
+ &strides,
+ &fwd_output_dims_tf_order,
+ &fwd_output_dims,
+ &padding_l, &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
- auto fwd_src_md =
- memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
- auto fwd_filter_md =
- memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
- auto fwd_out_md =
- memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
- auto fwd_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
- fwd_out_md, strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
+ auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
+ mkl_data_format);
+ auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ memory::format::hwio);
+ auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
+ mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(prop_kind::forward,
+ convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
+ strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
// Allocate output tensor and shape
@@ -475,22 +475,23 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward data primitive.
- auto bwd_desc = convolution_backward_data::desc(
- convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(),
- outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
+ auto bwd_desc = convolution_backward_data::desc(convolution_direct,
+ output.GetOpMemDesc(), filter.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l,
+ padding_r, TFPaddingToMklDnnPadding(padding_));
- auto bwd_pd = convolution_backward_data::primitive_desc(
- bwd_desc, cpu_engine, fwd_pd);
+ auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc,
+ cpu_engine,
+ fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
- } catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
@@ -501,8 +502,9 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
- const convolution_backward_data::primitive_desc& conv_pd,
- MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ const convolution_backward_data::primitive_desc& conv_pd,
+ MklDnnData<T>* filter, MklDnnData<T>* obp,
+ MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
@@ -512,11 +514,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
// Memory for output of convolution. Since we may need reorder on the
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
- bool output_reorder_required =
- output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc());
+ bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
+ conv_pd.diff_src_primitive_desc());
- net.push_back(convolution_backward_data(
- conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
+ net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(),
+ filter->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.