aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tatiana Shpeisman <shpeisman@google.com>2018-07-03 18:09:35 -0700
committerGravatar GitHub <noreply@github.com>2018-07-03 18:09:35 -0700
commitb2fe2a874bade4782aaca5c44bf29e7ff6c39200 (patch)
tree77a6f54a3e40bd8be2a8fe005fed56f27751c044
parent3b538660b1eb22e52ad455a17e01598508373969 (diff)
parent56150c9829b79c2249a4b90087ce25b1e6624f0b (diff)
Merge pull request #19399 from Intel-tensorflow/primreuse_conv_bwd
INTEL-MKL: Enhance Mkl conv2d backward (filter and input) ops with primitive reuse
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc659
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc476
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h222
3 files changed, 932 insertions, 425 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 356eed8b67..4e80f5acce 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -54,9 +54,311 @@ using mkldnn::stream;
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
-
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_ML
+
+struct MklConvBwdFilterParams {
+ memory::dims src_dims;
+ memory::dims diff_filter_dims;
+ memory::dims diff_bias_dims;
+ memory::dims diff_dst_dims;
+ memory::dims strides;
+ memory::dims dilations;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ padding_kind padding;
+
+ MklConvBwdFilterParams(memory::dims src_dims,
+ memory::dims diff_filter_dims, memory::dims diff_bias_dims,
+ memory::dims diff_dst_dims, memory::dims strides,
+ memory::dims dilations, memory::dims padding_left,
+ memory::dims padding_right, padding_kind padding) :
+ src_dims(src_dims), diff_filter_dims(diff_filter_dims),
+ diff_bias_dims(diff_bias_dims), diff_dst_dims(diff_dst_dims),
+ strides(strides), dilations(dilations),
+ padding_left(padding_left), padding_right(padding_right),
+ padding(padding) {
+ }
+};
+
+template <typename T>
+class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+ public:
+ explicit MklConv2DBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims) :
+ cpu_engine_(engine::cpu, 0) {
+ context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
+ // create conv primitive
+ if (context_.conv_bwd_filter == nullptr) {
+ Setup(convBwdFilterDims);
+ }
+ }
+
+ ~MklConv2DBwdFilterPrimitive() {}
+
+ // Convolution backward weights with bias
+ // src_data: input data buffer of src
+ // diff_filter_data: output data buffer of diff_filter
+ // diff_bias_data: output data buffer of diff_bias
+ // diff_dst_data: input data buffer of diff_dst
+ void Execute(const T* src_data, const T* diff_filter_data,
+ const T* diff_bias_data, const T* diff_dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_filter_data)));
+ context_.diff_bias_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_bias_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_filter_mem->set_data_handle(DummyData);
+ context_.diff_bias_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ // Convolution backward weights without bias
+ // src_data: input data buffer of src
+ // diff_filter_data: output data buffer of diff_filter
+ // diff_dst_data: input data buffer of diff_dst
+ void Execute(const T* src_data,
+ const T* diff_filter_data, const T* diff_dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_filter_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_filter_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ memory::format GetSrcMemoryFormat() const {
+ return context_.src_fmt;
+ }
+
+ memory::format GetDiffDstMemoryFormat() const {
+ return context_.diff_dst_fmt;
+ }
+
+ memory::format GetDiffFilterMemoryFormat() const {
+ return context_.diff_filter_fmt;
+ }
+
+ // convolution primitive
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+ GetPrimitiveDesc() const {
+ return context_.bwd_filter_pd;
+ }
+
+ private:
+ // Primitive reuse context for Conv2D bwd filter op
+ struct ConvBwdFilterContext {
+ // expected memory format for this primitive instance
+ memory::format src_fmt;
+ memory::format diff_dst_fmt;
+ memory::format diff_filter_fmt;
+
+    // convolution bwd filter primitive
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+ bwd_filter_pd;
+ std::shared_ptr<mkldnn::primitive> conv_bwd_filter;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> diff_filter_mem;
+ std::shared_ptr<mkldnn::memory> diff_bias_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+    // desc & primitive desc
+ std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc;
+ std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
+
+ // memory desc: forward & backward can share same memory desc
+ std::shared_ptr<mkldnn::memory::desc> src_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_filter_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_bias_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
+
+ // MKL pipeline
+ std::shared_ptr<mkldnn::stream> bwd_filter_stream;
+ std::vector<mkldnn::primitive> bwd_filter_primitives;
+
+ ConvBwdFilterContext() :
+ src_fmt(memory::format::any),
+ diff_dst_fmt(memory::format::any),
+ diff_filter_fmt(memory::format::any),
+ src_mem(nullptr), diff_filter_mem(nullptr),
+ diff_bias_mem(nullptr), diff_dst_mem(nullptr),
+ bwd_filter_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr),
+ src_md(nullptr), diff_filter_md(nullptr),
+ diff_bias_md(nullptr), diff_dst_md(nullptr),
+ bwd_filter_stream(nullptr) {
+ }
+ };
+
+ // Setup Conv2d backward filter (weights) primitives.
+ void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
+ // create memory descriptors for convolution data w/ no specified format
+ context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ context_.diff_dst_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_dst_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ context_.diff_filter_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_filter_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ if (!convBwdFilterDims.diff_bias_dims.empty())
+ context_.diff_bias_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_bias_dims},
+ MklDnnType<T>(), memory::format::x));
+
+ // create a convolution
+ if (!convBwdFilterDims.diff_bias_dims.empty()) {
+ context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
+ convolution_direct, *context_.src_md, *context_.diff_filter_md,
+ *context_.diff_bias_md, *context_.diff_dst_md,
+ convBwdFilterDims.strides, convBwdFilterDims.dilations,
+ convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
+ convBwdFilterDims.padding));
+ } else {
+ context_.bwd_filter_desc.reset(
+ new convolution_backward_weights::desc(
+ convolution_direct, *context_.src_md, *context_.diff_filter_md,
+ *context_.diff_dst_md, convBwdFilterDims.strides,
+ convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
+ convBwdFilterDims.padding_right, convBwdFilterDims.padding));
+ }
+
+ // create fwd primitive_desc
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct,
+ *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md,
+ convBwdFilterDims.strides,
+ convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
+ convBwdFilterDims.padding_right, convBwdFilterDims.padding));
+ context_.fwd_pd.reset(new convolution_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create backward conv primitive_desc
+ context_.bwd_filter_pd.reset(
+ new convolution_backward_weights::primitive_desc(
+ *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
+
+ // store the expected memory format
+ auto bwd_filter_pd = context_.bwd_filter_pd.get();
+ context_.src_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->src_primitive_desc().desc().data.format);
+ context_.diff_filter_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->diff_weights_primitive_desc().desc().data.format);
+ context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->diff_dst_primitive_desc().desc().data.format);
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(
+ bwd_filter_pd->src_primitive_desc(), DummyData));
+ context_.diff_filter_mem.reset(new memory(
+ bwd_filter_pd->diff_weights_primitive_desc(), DummyData));
+ context_.diff_dst_mem.reset(new memory(
+ bwd_filter_pd->diff_dst_primitive_desc(), DummyData));
+
+ // create convolution primitive and add it to net
+ if (!convBwdFilterDims.diff_bias_dims.empty()) {
+ context_.diff_bias_mem.reset(new memory(
+ {{{convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
+ memory::format::x}, cpu_engine_}, DummyData));
+ context_.conv_bwd_filter.reset(new convolution_backward_weights(
+ *context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
+ *context_.diff_filter_mem, *context_.diff_bias_mem));
+ } else {
+ context_.conv_bwd_filter.reset(new convolution_backward_weights(
+ *context_.bwd_filter_pd, *context_.src_mem,
+ *context_.diff_dst_mem, *context_.diff_filter_mem));
+ }
+
+ context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
+ }
+
+ struct ConvBwdFilterContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklConv2DBwdFilterPrimitive<T>* Get(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
+
+ // look into the pool for reusable primitive
+ conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
+ MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
+ convBwdFilterDims));
+
+ if (conv2d_bwd_filter == nullptr) {
+ conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
+ convBwdFilterDims);
+ MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
+ convBwdFilterDims, conv2d_bwd_filter);
+ }
+ return conv2d_bwd_filter;
+ }
+
+
+ private:
+ MklConv2DBwdFilterPrimitiveFactory() {}
+ ~MklConv2DBwdFilterPrimitiveFactory() {}
+
+ static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConv2DBwdFilterPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ static std::string CreateKey(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ std::string prefix = "conv2d_bwd_filter";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(convBwdFilterDims.src_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims);
+ key_creator.AddAsKey(convBwdFilterDims.strides);
+ key_creator.AddAsKey(convBwdFilterDims.dilations);
+ key_creator.AddAsKey(convBwdFilterDims.padding_left);
+ key_creator.AddAsKey(convBwdFilterDims.padding_right);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetConv2dBwdFilter(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ std::string key = CreateKey(convBwdFilterDims);
+ return this->GetOp(key);
+ }
+
+ void SetConv2dBwdFilter(
+ const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ std::string key = CreateKey(convBwdFilterDims);
+ this->SetOp(key, op);
+ }
+};
+
+#endif
+
#ifdef INTEL_MKL_ML
template <typename Device, class T>
@@ -442,11 +744,213 @@ class MklConv2DCustomBackpropFilterOp
: public MklConv2DBackpropCommonOp<Device, T> {
public:
explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {}
+ : MklConv2DBackpropCommonOp<Device, T>(context) {
+ }
+
~MklConv2DCustomBackpropFilterOp() {}
+ void Compute(OpKernelContext* context) {
+ try {
+ MklDnnData<T> src(&cpu_engine_);
+ MklDnnData<T> diff_dst(&cpu_engine_);
+ MklDnnData<T> diff_filter(&cpu_engine_); // output
+
+ // Input tensors
+ const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
+ const Tensor& src_tensor = MklGetInput(context, kInputIdx);
+ const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
+ const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
+
+ MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
+ GetMklShape(context, kInputIdx, &src_mkl_shape);
+ GetMklShape(context, kFilterIdx, &filter_mkl_shape);
+ GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+ // Allow operator-specific sanity checking of shapes.
+ ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
+
+ // Allow operator-specific generation of shapes.
+ // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // tensor containing shape of filter. So filter.shape() is not
+ // a correct way to get filter shape. These operator-specific calls
+ // allow this class to handle this case.
+ TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
+ TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
+ TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+
+ // Corner cases: output with 0 elements and 0 batch size.
+ Tensor* diff_filter_tensor = nullptr;
+ if (src_tf_shape.num_elements() == 0 ||
+ filter_tf_shape.num_elements() == 0 ||
+ diff_dst_tf_shape.num_elements() == 0) {
+ MklDnnShape diff_filter_mkl_shape;
+ diff_filter_mkl_shape.SetMklTensor(false);
+ TensorShape diff_filter_tf_shape = GetOutputTfShape(
+ src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
+ const int kOutputIdx = 0;
+ AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ CHECK_NOTNULL(diff_filter_tensor);
+
+ // if output tensor has more than 0 elements, we need to 0 them out.
+ auto diff_filter_data = diff_filter_tensor->flat<T>().data();
+ for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
+ diff_filter_data[i] = 0;
+ }
+ return;
+ }
+
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with prefix tf_order.
+ memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
+ memory::dims padding_left, padding_right, dilations,
+ strides, fwd_dst_dims;
+ memory::dims fwd_dst_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
+ this->data_format_, this->dilations_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
+ &strides, &dilations, &fwd_dst_dims_tf_order,
+ &fwd_dst_dims, &padding_left, &padding_right);
+ if (!context->status().ok()) return;
+
+ auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto fwd_src_md =
+ src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt);
+
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ if (!context->status().ok()) return;
+
+ auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor()
+ ? diff_dst_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims,
+ MklDnnType<T>(), tf_fmt);
+
+ memory::dims diff_bias_dims = {};
+ int64 depth = 0;
+ if (biasEnabled) {
+ TensorShape obp_tf_shape = GetTfShape(context, 2);
+ depth = (this->data_format_ == FORMAT_NCHW)
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(3);
+ diff_bias_dims = {static_cast<int>(depth)};
+ }
+
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+
+ MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
+ diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
+ padding_right, TFPaddingToMklDnnPadding(this->padding_));
+ conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims);
+ auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+
+      // allocate output tensors: diff_filter and diff_bias (w bias)
+ auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
+
+ // diff_filter
+ MklDnnShape diff_filter_mkl_shape;
+ diff_filter_mkl_shape.SetMklTensor(false);
+ // output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ Tensor* diff_bias_tensor = nullptr;
+ if (biasEnabled) {
+ TensorShape diff_bias_shape({depth});
+ AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
+ }
+
+ // check if src and diff_dst need reorder
+ std::vector<primitive> net;
+ T *src_data = nullptr;
+ if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ src.SetUsrMem(fwd_src_md, &src_tensor);
+ src.CheckReorderToOpMem(
+ bwd_filter_pd->src_primitive_desc(), &net);
+ src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
+ } else {
+ src_data = static_cast<T*>(const_cast<T*>(
+ src_tensor.flat<T>().data()));
+ }
+
+ T *diff_dst_data = nullptr;
+ if (diff_dst_md.data.format !=
+ conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ bwd_filter_pd->diff_dst_primitive_desc(), &net);
+ diff_dst_data = static_cast<T*>(
+ diff_dst.GetOpMem().get_data_handle());
+ } else {
+ diff_dst_data = static_cast<T*>(const_cast<T*>(
+ diff_dst_tensor.flat<T>().data()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+
+ // For backward filter, convert diff_filter back to Tensorflow layout
+ // Here we prepare to reorder op memory back to user memory
+ bool diff_filter_reorder_required = false;
+ T *diff_filter_data = nullptr;
+ if (GetOutputFormat(tf_fmt) !=
+ conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ // Allocate diff filter tensor as Tensorflow layout
+ diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
+ diff_filter_tensor);
+ diff_filter_reorder_required = true;
+ diff_filter.PrepareReorderToUserMemIfReq(
+ bwd_filter_pd->diff_weights_primitive_desc());
+ diff_filter_data = static_cast<T*>(
+ diff_filter.GetOpMem().get_data_handle());
+ } else {
+ diff_filter_data = static_cast<T*>(const_cast<T*>(
+ diff_filter_tensor->flat<T>().data()));
+ }
+
+ // Execute convolution filter bwd
+ if (biasEnabled) {
+ T* diff_bias_data = static_cast<T*>(const_cast<T*>(
+ diff_bias_tensor->flat<T>().data()));
+ conv2d_bwd_filter->Execute(src_data, diff_filter_data,
+ diff_bias_data, diff_dst_data);
+ } else {
+ conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ }
+
+ // Reorder diff_filter back to Tensorflow layout if necessary
+ if (diff_filter_reorder_required) {
+ std::vector<primitive> net;
+ diff_filter.InsertReorderToUserMem(&net);
+ stream(stream::kind::eager).submit(net).wait();
+ }
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
private:
+ const int kInputIndex_Filter = 1;
+ const int kInputIndex_InputSizes = 0;
const int kDilationH = 0, kDilationW = 1;
+ engine cpu_engine_ = engine(engine::cpu, 0);
+
+ // Validate input shapes.
+ // Function asserts that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -454,141 +958,44 @@ class MklConv2DCustomBackpropFilterOp
<< "Conv2DBackpropFilter: filter should not be in MKL Layout";
}
- size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ }
-
+ // Get TensorFlow shape of input tensor.
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
size_t input_idx = 0;
return GetTfShape(context, input_idx);
}
+ // Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
TensorShape filter_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
- &filter_tf_shape)
- .ok(),
- true);
+ &filter_tf_shape).ok(), true);
return filter_tf_shape;
}
+ // Get Tensorflow shape of output tensor (diff_filter),
+ // which is same as shape of filter.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
- // Shape of output of Conv2DBackpropFilter is same as shape of filter.
return filter_shape;
}
+ // Get the shape of output (diff_filter) in MKL-DNN order.
+ // Computes shape of output from input shape (fwd_input_dims)
+ // and filter shape (fwd_filter_dims).
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
- // Shape of output of Conv2DBackpropFilter is same as shape of filter.
return fwd_filter_dims;
}
+ // Output layout is Tensorflow's filter layout (HWIO).
memory::format GetOutputFormat(const memory::format data_format) {
- // Output layout is Tensorflow's filter layout (HWIO).
return memory::format::hwio;
}
- void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter,
- MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor,
- const memory::dims& strides,
- const memory::dims& dilations,
- const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) {
- CHECK_NOTNULL(context);
- CHECK_NOTNULL(input);
- CHECK_NOTNULL(filter);
- CHECK_NOTNULL(outbackprop);
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(output_tensor);
-
- MklDnnData<T>* bias_grad = nullptr;
- int depth = 0;
- if (biasEnabled) {
- // Data structure for bias_grad
- bias_grad = new MklDnnData<T>(&cpu_engine);
- TensorShape obp_tf_shape = GetTfShape(context, 2);
- depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat() ==
- FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
- memory::dims bias_grad_dims = {depth};
- bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x);
- }
-
- if (biasEnabled && (bias_grad != nullptr)) {
- // Create convolution backward weights with bias primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- bias_grad->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding) :
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- bias_grad->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- conv_fwd_pd);
-
- // Allocate output tensor.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
- bwd_output_format, output_tensor);
-
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
-
- // Allocate bias_grad tensor
- TensorShape bias_grad_shape({depth});
- Tensor* bias_grad_tensor = nullptr;
- AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor);
- memory::dims bias_grad_dims = {depth};
- // Since Bias is 1D, we use format::x from MKLDNN to represent it.
- auto bias_grad_md =
- memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x);
- bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor);
- bias_grad->SetUsrMemDataHandle(bias_grad_tensor);
-
- PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output,
- bias_grad);
- } else {
- // Create convolution backward weights primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding) :
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- conv_fwd_pd);
-
- // Allocate output tensor.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
- bwd_output_format, output_tensor);
-
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
- PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output);
- }
- }
-
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
@@ -623,40 +1030,8 @@ class MklConv2DCustomBackpropFilterOp
MklDnnShape bias_grad_mkl_shape;
bias_grad_mkl_shape.SetMklTensor(false);
- AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
- bias_grad_mkl_shape);
- }
-
- // Prepare and execute net - checks for input and output reorders.
- void PrepareAndExecutePrimitive(
- const convolution_backward_weights::primitive_desc& conv_pd,
- MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output,
- MklDnnData<T>* bias_grad = nullptr) {
- // Create reorders between user layout and MKL layout if it is needed and
- // add it to the net before convolution.
- std::vector<primitive> net;
- input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
- obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
-
- // For BackpropFilter, we convert the output tensor back in Tensorflow
- // layout.
- bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_pd.diff_weights_primitive_desc());
-
- if (biasEnabled && (bias_grad != nullptr)) {
- net.push_back(convolution_backward_weights(
- conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(),
- bias_grad->GetOpMem()));
- } else {
- net.push_back(convolution_backward_weights(
- conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
- }
-
- if (output_reorder_required) {
- output->InsertReorderToUserMem(&net);
- }
-
- stream(stream::kind::eager).submit(net).wait();
+ AllocateOutputSetMklShape(context, 1, bias_grad_tensor,
+ bias_grad_shape, bias_grad_mkl_shape);
}
};
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 21b18f9119..0af4568b47 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -55,9 +55,246 @@ using mkldnn::stream;
#endif
namespace tensorflow {
-
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_ML
+
+/// utility classes enabling primitive reuse for backward conv2d ops.
+struct MklConvBwdInputParams {
+ memory::dims diff_src_dims;
+ memory::dims filter_dims;
+ memory::dims diff_dst_dims;
+ memory::dims strides;
+ memory::dims dilations;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ padding_kind padding;
+
+ MklConvBwdInputParams(memory::dims diff_src_dims,
+ memory::dims filter_dims, memory::dims diff_dst_dims,
+ memory::dims strides, memory::dims dilations,
+ memory::dims padding_left, memory::dims padding_right,
+ padding_kind padding) :
+ diff_src_dims(diff_src_dims), filter_dims(filter_dims),
+ diff_dst_dims(diff_dst_dims), strides(strides),
+ dilations(dilations), padding_left(padding_left),
+ padding_right(padding_right), padding(padding) {
+ }
+};
+
+template <typename T>
+class MklConv2DBwdInputPrimitive : public MklPrimitive {
+ public:
+ explicit MklConv2DBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims) :
+ cpu_engine_(engine::cpu, 0) {
+ context_.bwd_input_stream.reset(new stream(stream::kind::eager));
+
+ // create conv primitive
+ if (context_.conv_bwd_input == nullptr) {
+ Setup(convBwdInputDims);
+ }
+ }
+ ~MklConv2DBwdInputPrimitive() {}
+
+  // Convolution backward data (input)
+ // diff_src_data: output data buffer of diff_src
+ // filter_data: input data buffer of filter (weights)
+ // diff_dst_data: input data buffer of dst
+ // Bias does not matter here
+ void Execute(const T* diff_src_data,
+ const T* filter_data, const T* diff_dst_data) {
+ context_.diff_src_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(diff_src_data)));
+ context_.filter_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(filter_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_input_stream->submit(context_.bwd_input_primitives);
+
+ // set back data handle
+ context_.diff_src_mem->set_data_handle(DummyData);
+ context_.filter_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ memory::format GetFilterMemoryFormat() const {
+ return context_.filter_fmt;
+ }
+
+ memory::format GetDiffDstMemoryFormat() const {
+ return context_.diff_dst_fmt;
+ }
+
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+ GetPrimitiveDesc() const {
+ return context_.bwd_input_pd;
+ }
+
+ private:
+ // Primitive reuse context for Conv2D Bwd Input op
+ struct ConvBwdInputContext {
+ // expected memory format for this primitive instance
+ memory::format filter_fmt;
+ memory::format diff_dst_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+ std::shared_ptr<mkldnn::memory> filter_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+ // convolution primitive
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+ bwd_input_pd;
+ std::shared_ptr<mkldnn::primitive> conv_bwd_input;
+
+    // desc & primitive desc
+ std::shared_ptr<mkldnn::convolution_backward_data::desc> bwd_input_desc;
+ std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
+
+ // memory desc: forward & backward can share same memory::desc
+ std::shared_ptr<memory::desc> diff_src_md;
+ std::shared_ptr<memory::desc> filter_md;
+ std::shared_ptr<memory::desc> diff_dst_md;
+
+ // MKL pipeline
+ std::shared_ptr<mkldnn::stream> bwd_input_stream;
+ std::vector<mkldnn::primitive> bwd_input_primitives;
+
+ ConvBwdInputContext() :
+ filter_fmt(memory::format::any), diff_dst_fmt(memory::format::any),
+ diff_src_mem(nullptr), filter_mem(nullptr), diff_dst_mem(nullptr),
+ bwd_input_pd(nullptr), conv_bwd_input(nullptr),
+ bwd_input_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr),
+ diff_src_md(nullptr), filter_md(nullptr), diff_dst_md(nullptr),
+ bwd_input_stream(nullptr) {
+ }
+ };
+
+
+ void Setup(const MklConvBwdInputParams& convBwdInputDims) {
+ // create memory descriptors for convolution data w/ no specified format
+ context_.diff_src_md.reset(new memory::desc(
+ {convBwdInputDims.diff_src_dims},
+ MklDnnType<T>(), memory::format::any));
+ context_.filter_md.reset(new memory::desc(
+ {convBwdInputDims.filter_dims},
+ MklDnnType<T>(), memory::format::any));
+ context_.diff_dst_md.reset(new memory::desc(
+ {convBwdInputDims.diff_dst_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ // create convolution primitives
+ context_.bwd_input_desc.reset(new convolution_backward_data::desc(
+ convolution_direct, *context_.diff_src_md, *context_.filter_md,
+ *context_.diff_dst_md, convBwdInputDims.strides,
+ convBwdInputDims.dilations, convBwdInputDims.padding_left,
+ convBwdInputDims.padding_right, convBwdInputDims.padding));
+
+ context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, *context_.diff_src_md, *context_.filter_md,
+ *context_.diff_dst_md, convBwdInputDims.strides,
+ convBwdInputDims.dilations, convBwdInputDims.padding_left,
+ convBwdInputDims.padding_right, convBwdInputDims.padding));
+
+ context_.fwd_pd.reset(new convolution_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create backward conv prim desc
+ context_.bwd_input_pd.reset(
+ new convolution_backward_data::primitive_desc(
+ *context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd));
+
+ // create memory primitive based on dummy data
+ context_.diff_src_mem.reset(new memory(
+ context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData));
+ context_.filter_mem.reset(new memory(
+ context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData));
+ context_.diff_dst_mem.reset(new memory(
+ context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData));
+
+ // store the expected memory format
+ context_.filter_fmt = static_cast<memory::format>(
+ context_.bwd_input_pd.get()->weights_primitive_desc().desc().data.format);
+ context_.diff_dst_fmt = static_cast<memory::format>(
+ context_.bwd_input_pd.get()->diff_dst_primitive_desc().desc().data.format);
+
+ // create convolution primitive and add it to net
+ context_.conv_bwd_input.reset(new convolution_backward_data(
+ *context_.bwd_input_pd, *context_.diff_dst_mem,
+ *context_.filter_mem, *context_.diff_src_mem));
+
+ context_.bwd_input_primitives.push_back(*context_.conv_bwd_input);
+ }
+
+ struct ConvBwdInputContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+ private:
+ MklConv2DBwdInputPrimitiveFactory() {}
+ ~MklConv2DBwdInputPrimitiveFactory() {}
+
+ public:
+ static MklConv2DBwdInputPrimitive<T>* Get(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
+
+ // look into the pool for reusable primitive
+ conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
+ MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
+ convBwdInputDims));
+
+ if (conv2d_bwd_input == nullptr) {
+ conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
+ convBwdInputDims);
+ MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
+ convBwdInputDims, conv2d_bwd_input);
+ }
+ return conv2d_bwd_input;
+ }
+
+ private:
+ static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
+ static MklConv2DBwdInputPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ static std::string CreateKey(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ std::string prefix = "conv2d_bwd_input";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
+ key_creator.AddAsKey(convBwdInputDims.filter_dims);
+ key_creator.AddAsKey(convBwdInputDims.diff_dst_dims);
+ key_creator.AddAsKey(convBwdInputDims.strides);
+ key_creator.AddAsKey(convBwdInputDims.dilations);
+ key_creator.AddAsKey(convBwdInputDims.padding_left);
+ key_creator.AddAsKey(convBwdInputDims.padding_right);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetConv2dBwdInput(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ std::string key = CreateKey(convBwdInputDims);
+ return this->GetOp(key);
+ }
+
+ void SetConv2dBwdInput(
+ const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ std::string key = CreateKey(convBwdInputDims);
+ this->SetOp(key, op);
+ }
+};
+
+#endif
+
#ifdef INTEL_MKL_ML
template <typename Device, class T>
@@ -365,13 +602,173 @@ class MklConv2DCustomBackpropInputOp
: public MklConv2DBackpropCommonOp<Device, T> {
public:
explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {}
+ : MklConv2DBackpropCommonOp<Device, T>(context) {
+ }
+
~MklConv2DCustomBackpropInputOp() {}
+ void Compute(OpKernelContext* context) {
+ try {
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> diff_dst(&cpu_engine);
+
+ // Input tensors
+ const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
+ const Tensor& src_tensor = MklGetInput(context, kInputIdx);
+ const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
+ const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
+
+ MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
+ GetMklShape(context, kInputIdx, &src_mkl_shape);
+ GetMklShape(context, kFilterIdx, &filter_mkl_shape);
+ GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+ // Allow operator-specific sanity checking of shapes.
+ ValidateMklShapes(src_mkl_shape, filter_mkl_shape,
+ diff_dst_mkl_shape);
+
+ // Allow operator-specific generation of shapes.
+ // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // tensor containing shape of filter. So filter.shape() is not
+ // a correct way to get filter shape. These operator-specific calls
+ // allow this class to handle this case.
+ TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
+ TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
+ TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+
+ // Corner cases: output with 0 elements and 0 batch size.
+ Tensor* diff_src_tensor = nullptr;
+ if (src_tf_shape.num_elements() == 0 ||
+ filter_tf_shape.num_elements() == 0 ||
+ diff_dst_tf_shape.num_elements() == 0) {
+ MklDnnShape diff_src_mkl_shape;
+ diff_src_mkl_shape.SetMklTensor(false);
+ TensorShape diff_src_tf_shape = GetOutputTfShape(
+ src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
+ const int kOutputIdx = 0;
+ AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
+ diff_src_tf_shape, diff_src_mkl_shape);
+ CHECK_NOTNULL(diff_src_tensor);
+
+ // if output tensor has more than 0 elements, we need to 0 them out.
+ auto diff_src_data = diff_src_tensor->flat<T>().data();
+ for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) {
+ diff_src_data[i] = 0;
+ }
+ return;
+ }
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with postfix tf_order.
+ memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
+ memory::dims padding_left, padding_right, dilations, strides;
+ memory::dims fwd_output_dims, fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
+ this->data_format_, this->dilations_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
+ &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
+ &padding_left, &padding_right);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+
+ // If filter is in MKL layout, then simply grab filter layout;
+ // otherwise, construct filter in TF layout.
+ // For TF layout, filter is in HWIO format.
+ auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ memory::format::hwio);
+
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ if (!context->status().ok()) return;
+ auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor()
+ ? diff_dst_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims,
+ MklDnnType<T>(), tf_fmt);
+
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+
+ MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
+ diff_dst_dims, strides, dilations, padding_left, padding_right,
+ TFPaddingToMklDnnPadding(this->padding_));
+ conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
+ convBwdInputDims);
+ auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+ // allocate output tensor
+ auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
+ auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
+ auto bwd_diff_src_format = GetOutputFormat(tf_fmt);
+ MklDnnShape diff_src_mkl_shape;
+ diff_src_mkl_shape.SetMklTensor(true);
+ diff_src_mkl_shape.SetMklLayout(&diff_src_pd);
+ diff_src_mkl_shape.SetElemType(MklDnnType<T>());
+ diff_src_mkl_shape.SetTfLayout(bwd_diff_src_dims.size(),
+ bwd_diff_src_dims, bwd_diff_src_format);
+ TensorShape diff_src_tf_shape;
+ diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T));
+ AllocateOutputSetMklShape(context, 0, &diff_src_tensor,
+ diff_src_tf_shape, diff_src_mkl_shape);
+
+ T *diff_src_data = static_cast<T*>(const_cast<T*>(
+ diff_src_tensor->flat<T>().data()));
+
+ // check if filter and diff_dst need reorder
+ std::vector<primitive> net;
+ T* filter_data = nullptr;
+ if (fwd_filter_md.data.format !=
+ conv2d_bwd_input->GetFilterMemoryFormat()) {
+ filter.SetUsrMem(fwd_filter_md, &filter_tensor);
+ filter.CheckReorderToOpMem(
+ bwd_input_pd->weights_primitive_desc(),
+ &net);
+ filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
+ } else {
+ filter_data = static_cast<T*>(const_cast<T*>(
+ filter_tensor.flat<T>().data()));
+ }
+
+ T* diff_dst_data = nullptr;
+ if (diff_dst_md.data.format !=
+ conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ bwd_input_pd->diff_dst_primitive_desc(), &net);
+ diff_dst_data = static_cast<T*>(
+ diff_dst.GetOpMem().get_data_handle());
+ } else {
+ diff_dst_data = static_cast<T*>(const_cast<T*>(
+ diff_dst_tensor.flat<T>().data()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+
+ // execute convolution input bwd
+ conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
private:
- const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0,
- kInputIndex_OutBackProp = 2;
+ const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0;
const int kDilationH = 0, kDilationW = 1;
+ engine cpu_engine = engine(engine::cpu, 0);
+
+ // Validate input shapes.
+ // Function asserts that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -382,8 +779,7 @@ class MklConv2DCustomBackpropInputOp
<< "Conv2DBackpropInput: input should not be in MKL Layout";
}
- size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; }
-
+ // Get TensorFlow shape of input tensor.
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
TensorShape input_tf_shape;
@@ -395,72 +791,32 @@ class MklConv2DCustomBackpropInputOp
return input_tf_shape;
}
+ // Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
return GetTfShape(context, kInputIndex_Filter);
}
+ // Get the Tensorflow shape of Output (diff_src),
+ // which is same as shape of Conv2D 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
- // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
return input_shape;
}
+  // Get the output (diff_src) dims in MKL-DNN order,
+  // which are the same as the dims of Conv2D 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
- // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
return fwd_input_dims;
}
+ // Output layout is Tensorflow's layout in data format order.
memory::format GetOutputFormat(const memory::format data_format) {
- // Output layout is Tensorflow's layout in data format order.
return data_format;
}
- void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter,
- MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor,
- const memory::dims& strides,
- const memory::dims& dilations,
- const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) {
- CHECK_NOTNULL(context);
- CHECK_NOTNULL(input);
- CHECK_NOTNULL(filter);
- CHECK_NOTNULL(outbackprop);
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(output_tensor);
-
- // Create convolution backward data primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_data::desc(convolution_direct,
- output->GetOpMemDesc(), filter->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding):
- convolution_backward_data::desc(convolution_direct,
- output->GetOpMemDesc(), filter->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
-
- auto bwd_pd = convolution_backward_data::primitive_desc(
- bwd_desc, cpu_engine, conv_fwd_pd);
-
- // Allocate output tensor in TensorFlow and MKL layout.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
- output_tensor);
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
-
- PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output);
- }
-
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
@@ -487,22 +843,6 @@ class MklConv2DCustomBackpropInputOp
AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
output_mkl_shape);
}
-
- // Prepare and execute net - checks for input and output reorders.
- void PrepareAndExecutePrimitive(
- const convolution_backward_data::primitive_desc& conv_pd,
- MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
- // Create reorders between user layout and MKL layout if it is needed and
- // add it to the net before convolution.
- std::vector<primitive> net;
- filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
- obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
-
- net.push_back(convolution_backward_data(
- conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
-
- stream(stream::kind::eager).submit(net).wait();
- }
};
#endif // INTEL_MKL_ML
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 8333a09316..5e1a5001dc 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <limits>
#include <string>
#include <vector>
+#include <memory>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -349,6 +350,7 @@ class MklDnnConvUtil {
}
};
+
/////////////////////////////////////////////////////////////////////
/// Common class that implements Conv2DBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
@@ -388,227 +390,17 @@ class MklConv2DBackpropCommonOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
- void Compute(OpKernelContext* context) override {
- try {
- auto cpu_engine = engine(engine::cpu, 0);
-
- // Prepare common tensors for Conv2DBackpropInput and
- // Conv2DBackpropFilter.
- MklDnnData<T> input(&cpu_engine);
- MklDnnData<T> filter(&cpu_engine);
- MklDnnData<T> outbackprop(&cpu_engine);
- MklDnnData<T> output(&cpu_engine);
-
- // Input tensors
- const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
- const Tensor& input_tensor = MklGetInput(context, kInputIdx);
- const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
- const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx);
-
- MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape;
- GetMklShape(context, kInputIdx, &input_mkl_shape);
- GetMklShape(context, kFilterIdx, &filter_mkl_shape);
- GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape);
- // Allow operator-specific sanity checking of shapes.
- ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape);
-
- // Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
- // tensor containing shape of filter. So filter.shape() is not
- // a correct way to get filter shape. These operator-specific calls
- // allow this class to handle this case.
- TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor);
- TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
- TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx);
-
- // Corner cases: output with 0 elements and 0 batch size.
- Tensor* output_tensor = nullptr;
- if (input_tf_shape.num_elements() == 0 ||
- filter_tf_shape.num_elements() == 0 ||
- outbprop_tf_shape.num_elements() == 0) {
- MklDnnShape output_mkl_shape;
- output_mkl_shape.SetMklTensor(false);
- TensorShape output_tf_shape = GetOutputTfShape(
- input_tf_shape, filter_tf_shape, outbprop_tf_shape);
- const int kOutputIdx = 0;
- AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor,
- output_tf_shape, output_mkl_shape);
- CHECK_NOTNULL(output_tensor);
-
- // if output tensor has more than 0 elements, we need to 0 them out.
- for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) {
- output_tensor->flat<T>().data()[i] = 0;
- }
-
- return;
- }
-
- // By default, all dims are in MKL order. Only dims in TF order
- // are those with prefix tf_order.
- memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
- memory::dims padding_l, padding_r, dilations, strides, fwd_output_dims;
- memory::dims fwd_output_dims_tf_order;
-
- // Get forward convolution parameters.
- MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
- dilations_);
- conv_utl.GetConvFwdSizesInMklOrder(
- input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims,
- &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
- &padding_l, &padding_r);
- if (!context->status().ok()) return;
-
- // Create Convolution forward descriptor since Convolution backward
- // API needs it. For that, we first need to create input, filter
- // and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
- // If input is in MKL layout, then simply grab input layout; otherwise,
- // construct input TF layout. For TF layout, although input shape
- // required is in MKL-DNN order, the layout is Tensorflow's layout
- // (NHWC or NCHW depending on data format).
- auto fwd_input_md =
- input_mkl_shape.IsMklTensor()
- ? input_mkl_shape.GetMklLayout()
- : memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt);
- // If filter is in MKL layout, then simply grab filter layout; otherwise
- // construct filter in TF layout. For TF layout, filter is in HWIO format.
- auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
- // Tensorflow Output of Conv2D is in data_format order.
- auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt);
-
- const int kDilationH = 0, kDilationW = 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
- auto fwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0)?
- convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_input_md,
- fwd_filter_md, fwd_out_md,
- strides, dilations, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_)) :
- convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_input_md,
- fwd_filter_md, fwd_out_md,
- strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
- auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
-
- // Create memory for user data. Describe how the inputs and outputs of
- // Convolution look like. Also specify buffers containing actual input
- // and output data.
-
- // Since this is a common class for both Conv2DBackpropFilter and
- // Conv2DBackpropInput, we skip SetUsrMem call for input tensor (for
- // Conv2DBackpropInput) and for filter tensor (for
- // conv2DBackpropFilter) depending on which tensor is int32 type.
- size_t input_with_sizes = GetInputTensorIndexWithSizes();
- if (input_with_sizes != kInputIdx) {
- // Shape of Conv2DBackpropFilter's input is same as Conv2D input.
- input.SetUsrMem(fwd_input_md, &input_tensor);
- } else if (input_with_sizes != kFilterIdx) {
- // Shape of Conv2DBackpropInput's filter is same as Conv2D filter.
- filter.SetUsrMem(fwd_filter_md, &filter_tensor);
- }
-
- conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims);
- if (!context->status().ok()) return;
- if (outbprop_mkl_shape.IsMklTensor()) {
- // If outbackprop is in Mkl layout, then simply grab it.
- auto outbprop_md = outbprop_mkl_shape.GetMklLayout();
- outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor);
- } else {
- // If outbackprop is in TensorFlow layout, then we need to create memory
- // descriptor for it. Outbackprop shape is data format order.
- outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor);
- }
-
- // Operator specific call to get output shape and data_format.
- auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims);
- auto bwd_output_format = GetOutputFormat(tf_fmt);
- output.SetUsrMem(bwd_output_dims, bwd_output_format);
-
- // Create memory descriptors for convolution data w/ no specified format.
- input.SetOpMemDesc(fwd_input_dims, memory::format::any);
- filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
- outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any);
- output.SetOpMemDesc(bwd_output_dims, memory::format::any);
-
- // Operator-specific call to create and execute primitive.
- CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
- &outbackprop, &output, &output_tensor,
- strides, dilations, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_),
- bwd_output_dims, bwd_output_format);
- } catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
- }
- }
-
- /// Pure virtual function to allow operator to check for validity of input
- /// shapes. Function asserts that input shapes are valid.
- virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
- const MklDnnShape& filter_mkl_shape,
- const MklDnnShape& outbprop_mkl_shape) = 0;
-
- /// Operator-specific function that returns index of input that is
- /// representing input sizes. For Conv2DBackpropFilter it returns 1 since
- /// filter for this operator is filter shape. For Conv2DBackpropInput it
- /// returns 0 (for input).
- virtual size_t GetInputTensorIndexWithSizes() = 0;
-
- /// Get TensorFlow shape of input tensor.
- virtual TensorShape MakeInputTfShape(OpKernelContext* context,
- const Tensor& input_tensor) = 0;
-
- /// Get TensorFlow shape of filter tensor.
- virtual TensorShape MakeFilterTfShape(OpKernelContext* context,
- const Tensor& filter_tensor) = 0;
-
- /// Get the TensorFlow shape of output tensor.
- virtual TensorShape GetOutputTfShape(const TensorShape& input_shape,
- const TensorShape& filter_shape,
- const TensorShape& outbprop_shape) = 0;
-
- /// Get shape of output in MKL-DNN order. Computes shape of output from
- /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims).
- virtual const memory::dims& GetOutputDims(
- const memory::dims& fwd_input_dims,
- const memory::dims& fwd_filter_dims) = 0;
-
- /// Get data_format of output in MKL-DNN order. If output data format is
- /// same as input data format, then it simply returns value of data_format
- /// parameter as it is.
- virtual memory::format GetOutputFormat(const memory::format data_format) = 0;
-
- /// Create and execute the primitive storing output in the output_tensor.
- virtual void CreatePrimitive(OpKernelContext* context,
- const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
- MklDnnData<T>* output, Tensor** output_tensor, const memory::dims& strides,
- const memory::dims& dilations, const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) = 0;
-
- // Get the data_format {NCHW, NHWC}
- TensorFormat GetTFDataFormat() { return data_format_; }
-
- private:
+ protected:
+ // data members accessible to derived classes.
std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
- TensorFormat data_format_;
+ TensorFormat data_format_; // NCHW or NHWC
};
+
#endif // INTEL_MKL_ML
+
/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
/// output of node fusion in the graph