about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/mkl_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.cc 193
1 files changed, 97 insertions, 96 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index cede0b9dd6..b568973220 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -70,23 +70,25 @@ struct MklConvFwdParams {
memory::dims padding_left;
memory::dims padding_right;
- MklConvFwdParams(memory::dims src_dims,
- memory::dims filter_dims, memory::dims bias_dims,
- memory::dims dst_dims, memory::dims strides,
- memory::dims dilations, memory::dims padding_left,
- memory::dims padding_right) :
- src_dims(src_dims), filter_dims(filter_dims),
- bias_dims(bias_dims), dst_dims(dst_dims),
- strides(strides), dilations(dilations),
- padding_left(padding_left), padding_right(padding_right) {
- }
+ MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
+ memory::dims bias_dims, memory::dims dst_dims,
+ memory::dims strides, memory::dims dilations,
+ memory::dims padding_left, memory::dims padding_right)
+ : src_dims(src_dims),
+ filter_dims(filter_dims),
+ bias_dims(bias_dims),
+ dst_dims(dst_dims),
+ strides(strides),
+ dilations(dilations),
+ padding_left(padding_left),
+ padding_right(padding_right) {}
};
template <typename T>
-class MklConv2DFwdPrimitive: public MklPrimitive {
+class MklConv2DFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_fwd == nullptr) {
@@ -101,8 +103,8 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
// filter_data: input data buffer of filter (weights)
// bias_data: input data buffer of bias
// dst_data: output data buffer of dst
- void Execute(const T* src_data, const T* filter_data,
- const T* bias_data, const T* dst_data) {
+ void Execute(const T* src_data, const T* filter_data, const T* bias_data,
+ const T* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.filter_mem->set_data_handle(
@@ -126,8 +128,7 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
// src_data: input data buffer of src
// filter_data: input data buffer of filter (weights)
// dst_data: output data buffer of dst
- void Execute(const T* src_data, const T* filter_data,
- const T* dst_data) {
+ void Execute(const T* src_data, const T* filter_data, const T* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.filter_mem->set_data_handle(
@@ -142,13 +143,9 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
context_.dst_mem->set_data_handle(DummyData);
}
- memory::format GetSrcMemoryFormat() const {
- return context_.src_fmt;
- }
+ memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
- memory::format GetFilterMemoryFormat() const {
- return context_.filter_fmt;
- }
+ memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
GetPrimitiveDesc() const {
@@ -184,43 +181,50 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
- ConvFwdContext() :
- src_fmt(memory::format::any), filter_fmt(memory::format::any),
- src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr),
- dst_mem(nullptr), fwd_desc(nullptr),
- src_md(nullptr), filter_md(nullptr), bias_md(nullptr),
- fwd_pd(nullptr), conv_fwd(nullptr), fwd_stream(nullptr) {
- }
+ ConvFwdContext()
+ : src_fmt(memory::format::any),
+ filter_fmt(memory::format::any),
+ src_mem(nullptr),
+ filter_mem(nullptr),
+ bias_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ src_md(nullptr),
+ filter_md(nullptr),
+ bias_md(nullptr),
+ fwd_pd(nullptr),
+ conv_fwd(nullptr),
+ fwd_stream(nullptr) {}
};
void Setup(const MklConvFwdParams& convFwdDims) {
// create memory descriptors for convolution data w/ no specified format
- context_.src_md.reset(new memory::desc({convFwdDims.src_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.src_md.reset(new memory::desc(
+ {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any));
- context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.filter_md.reset(new memory::desc(
+ {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any));
- context_.dst_md.reset(new memory::desc({convFwdDims.dst_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.dst_md.reset(new memory::desc(
+ {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any));
if (!convFwdDims.bias_dims.empty())
- context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.bias_md.reset(new memory::desc(
+ {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any));
// create a convolution
if (!convFwdDims.bias_dims.empty()) {
- context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *context_.src_md, *context_.filter_md,
- *context_.bias_md, *context_.dst_md,
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
} else {
- context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *context_.src_md, *context_.filter_md,
- *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
- convFwdDims.padding_left, convFwdDims.padding_right,
- padding_kind::zero));
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.dst_md, convFwdDims.strides,
+ convFwdDims.dilations, convFwdDims.padding_left,
+ convFwdDims.padding_right, padding_kind::zero));
}
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
@@ -234,24 +238,26 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// create memory primitive based on dummy data
- context_.src_mem.reset(new memory(
- context_.fwd_pd.get()->src_primitive_desc(), DummyData));
- context_.filter_mem.reset(new memory(
- context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
- context_.dst_mem.reset(new memory(
- context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+ context_.src_mem.reset(
+ new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
+ context_.filter_mem.reset(
+ new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
- context_.bias_mem.reset(new memory({{{convFwdDims.bias_dims},
- MklDnnType<T>(), memory::format::x}, cpu_engine_}, DummyData));
- context_.conv_fwd.reset(new convolution_forward(
- *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
- *context_.bias_mem, *context_.dst_mem));
+ context_.bias_mem.reset(new memory(
+ {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
+ cpu_engine_},
+ DummyData));
+ context_.conv_fwd.reset(new convolution_forward(
+ *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
+ *context_.bias_mem, *context_.dst_mem));
} else {
- context_.conv_fwd.reset(new convolution_forward(
- *context_.fwd_pd, *context_.src_mem,
- *context_.filter_mem, *context_.dst_mem));
+ context_.conv_fwd.reset(
+ new convolution_forward(*context_.fwd_pd, *context_.src_mem,
+ *context_.filter_mem, *context_.dst_mem));
}
context_.fwd_primitives.push_back(*context_.conv_fwd);
@@ -266,19 +272,19 @@ template <typename T>
class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
-
- // try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*> (
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(
- convFwdDims, conv2d_fwd);
- }
- return conv2d_fwd;
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+
+ // try to find a suitable one in pool
+ conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
+ convFwdDims));
+
+ if (conv2d_fwd == nullptr) {
+ conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
+ conv2d_fwd);
+ }
+ return conv2d_fwd;
}
private:
@@ -312,7 +318,7 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive *op) {
+ void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
std::string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -865,22 +871,24 @@ class MklConv2DOp : public OpKernel {
dilations[kDilationW] -= 1;
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T> *conv2d_fwd = nullptr;
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
}
// allocate output tensors output_tensor and filter_out_tensor
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
- conv_fwd_pd = conv2d_fwd->GetPrimitiveDesc();
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
+ conv2d_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -891,31 +899,25 @@ class MklConv2DOp : public OpKernel {
T* dst_data = static_cast<T*>(dst_tensor->flat<T>().data());
// check whether src/filter need reorder
- std::vector<primitive> net;
T *src_data = nullptr;
if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
- src.CheckReorderToOpMem(
- conv_fwd_pd.get()->src_primitive_desc(), &net);
+ src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
- src_data = static_cast<T*>(const_cast<T*>(
- src_tensor.flat<T>().data()));
+ src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
- T *filter_data = nullptr;
+ T* filter_data = nullptr;
if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
- filter.CheckReorderToOpMem(
- conv_fwd_pd.get()->weights_primitive_desc(),
- filter.GetTensorBuffer(filter_out_tensor), &net);
+ filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
+ filter.GetTensorBuffer(filter_out_tensor));
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
- filter_data = static_cast<T*>(const_cast<T*>(
- filter_tensor.flat<T>().data()));
+ filter_data =
+ static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
- stream(stream::kind::eager).submit(net).wait();
-
// execute convolution
if (biasEnabled) {
@@ -1010,16 +1012,15 @@ class MklConv2DOp : public OpKernel {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer.
- std::vector<primitive> net;
- src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
+ src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());
// rather than re-order to a temp buffer, reorder directly to the
// filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
- filter->GetTensorBuffer(filter_out_tensor),
- &net);
+ filter->GetTensorBuffer(filter_out_tensor));
// Create convolution primitive and add it to net.
+ std::vector<primitive> net;
if (bias) {
CHECK_EQ(biasEnabled, true);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),