Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc  174
1 file changed, 103 insertions, 71 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_ML_ONLY
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
};
template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdInputPrimitive(
- const MklConvBwdInputParams& convBwdInputDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
Setup(convBwdInputDims);
}
}
- ~MklConv2DBwdInputPrimitive() {}
+ ~MklConvBwdInputPrimitive() {}
// Convolution backward input (data)
// diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
private:
- // Primitive reuse context for Conv2D Bwd Input op
+ // Primitive reuse context for Conv Bwd Input op
struct ConvBwdInputContext {
// expected memory format for this primitive instance
memory::format filter_fmt;
@@ -174,7 +174,6 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
};
-
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(new memory::desc(
@@ -235,38 +234,41 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
- MklConv2DBwdInputPrimitiveFactory() {}
- ~MklConv2DBwdInputPrimitiveFactory() {}
+ MklConvBwdInputPrimitiveFactory() {}
+ ~MklConvBwdInputPrimitiveFactory() {}
public:
- static MklConv2DBwdInputPrimitive<T>* Get(
- const MklConvBwdInputParams& convBwdInputDims) {
- MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
-
- // look into the pool for reusable primitive
- conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
- convBwdInputDims));
-
- if (conv2d_bwd_input == nullptr) {
- conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
- convBwdInputDims);
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
- convBwdInputDims, conv2d_bwd_input);
+ static MklConvBwdInputPrimitive<T>* Get(
+ const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
+
+ if (do_not_cache) { /* Always allocate primitive */
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+ convBwdInputDims));
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
+ }
}
- return conv2d_bwd_input;
+
+ return conv_bwd_input;
}
private:
- static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
- static MklConv2DBwdInputPrimitiveFactory instance_;
+ static MklConvBwdInputPrimitiveFactory& GetInstance() {
+ static MklConvBwdInputPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
- string prefix = "conv2d_bwd_input";
+ string prefix = "conv_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
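
[editor note] The cache key concatenates the "conv_bwd_input" prefix with each dimension and parameter list. A minimal sketch of such a key builder, with hypothetical simplified names standing in for the real FactoryKeyCreator:

    #include <string>
    #include <vector>

    // Hypothetical, simplified stand-in for FactoryKeyCreator.
    class KeyBuilder {
     public:
      void AddAsKey(const std::string& s) { key_ += s + ";"; }
      void AddAsKey(const std::vector<int>& dims) {
        for (int d : dims) key_ += std::to_string(d) + ",";
        key_ += ";";
      }
      const std::string& GetKey() const { return key_; }

     private:
      std::string key_;  // e.g. "conv_bwd_input;1,16,7,7,;..."
    };
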
@@ -279,14 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims) {
+ MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
- void SetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
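
[editor note] GetConvBwdInput/SetConvBwdInput wrap a key-to-primitive pool. A hedged sketch of the underlying lookup-or-create pattern, including the do_not_cache bypass (hypothetical names, not the actual MklPrimitiveFactory API):

    #include <string>
    #include <unordered_map>

    struct Primitive {};  // hypothetical stand-in for MklPrimitive

    class PrimitivePool {
     public:
      // Returns a pooled primitive for `key`, creating one on a miss.
      // With do_not_cache, always allocates a fresh primitive that the
      // caller owns and must delete after use.
      Primitive* Get(const std::string& key, bool do_not_cache) {
        if (do_not_cache) return new Primitive();
        auto it = pool_.find(key);
        if (it != pool_.end()) return it->second;
        Primitive* p = new Primitive();
        pool_[key] = p;  // pooled primitives stay owned by the pool
        return p;
      }

     private:
      std::unordered_map<std::string, Primitive*> pool_;
    };
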
@@ -594,23 +595,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
#else
template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropInputOp() {}
+ ~MklConvCustomBackpropInputOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
+      // This flag indicates whether the op is Conv2D or Conv3D.
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
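
[editor note] The isConv2D flag keys off the rank of the strides attribute: Conv2D carries four strides (e.g. NHWC), Conv3D five (e.g. NDHWC). A tiny illustration of that assumption (example values, not from the patch):

    #include <vector>

    int main() {
      std::vector<int> strides_2d = {1, 1, 1, 1};     // NHWC:  N, H, W, C
      std::vector<int> strides_3d = {1, 1, 1, 1, 1};  // NDHWC: N, D, H, W, C
      bool is_conv2d = (strides_2d.size() == 4);      // true -> Conv2D path
      return is_conv2d ? 0 : 1;
    }
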
@@ -626,7 +638,7 @@ class MklConv2DCustomBackpropInputOp
diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
// tensor containing shape of filter. So filter.shape() is not
// a correct way to get filter shape. These operator-specific calls
// allow this class to handle this case.
@@ -655,6 +667,7 @@ class MklConv2DCustomBackpropInputOp
}
return;
}
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +686,18 @@ class MklConv2DCustomBackpropInputOp
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
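
[editor note] This hunk generalizes the 2D-only layout choices: both the data format helper and the TF-layout filter format now branch on isConv2D. A compact sketch of the filter mapping, using a hypothetical enum that mirrors the mkldnn 0.x format tags used in this file:

    // Hypothetical mirror of the mkldnn 0.x memory::format tags.
    enum class FilterFormat { hwio, dhwio };

    FilterFormat TfLayoutFilterFormat(bool is_conv2d) {
      return is_conv2d ? FilterFormat::hwio    // Conv2D filter: H, W, In, Out
                       : FilterFormat::dhwio;  // Conv3D filter: D, H, W, In, Out
    }
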
@@ -689,18 +705,25 @@ class MklConv2DCustomBackpropInputOp
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims,
MklDnnType<T>(), tf_fmt);
-      dilations[kDilationH] -= 1;
-      dilations[kDilationW] -= 1;
-
-      MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
-      conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+      for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
+      MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
- convBwdInputDims);
- auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+      // We don't cache these primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if the primitive
+      // descriptor includes potentially large buffers. MKL DNN allocates
+      // buffers in the following cases:
+      // 1. Legacy CPU without AVX512/AVX2, or
+      // 2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(fwd_filter_dims, strides));
+ conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+ do_not_cache);
+ auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
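
[editor note] On the dilation adjustment above: TF expresses dilation as a factor >= 1 (1 means dense), while MKL DNN expects the number of skipped elements, i.e. factor - 1. A small worked example of the conversion loop:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> dilations = {2, 2, 2};  // TF-style (>= 1)
      // MKL-DNN-style: 0 means no dilation, so subtract 1 everywhere.
      for (size_t i = 0; i < dilations.size(); i++) dilations[i] -= 1;
      for (int d : dilations) std::printf("%d ", d);  // prints: 1 1 1
      return 0;
    }
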
@@ -723,7 +746,7 @@ class MklConv2DCustomBackpropInputOp
// check if filter and diff_dst need reorder
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
- conv2d_bwd_input->GetFilterMemoryFormat()) {
+ conv_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +756,7 @@ class MklConv2DCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
- if (diff_dst_md.data.format !=
- conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -745,7 +767,12 @@ class MklConv2DCustomBackpropInputOp
}
// execute convolution input bwd
- conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) {
+ delete conv_bwd_input;
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -770,7 +797,7 @@ class MklConv2DCustomBackpropInputOp
// of the Tensor and never an actual tensor. So it will never be in MKL
// layout.
CHECK(!input_mkl_shape.IsMklTensor())
- << "Conv2DBackpropInput: input should not be in MKL Layout";
+ << "ConvBackpropInput: input should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -778,10 +805,10 @@ class MklConv2DCustomBackpropInputOp
const Tensor& input_tensor) {
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
- CHECK_EQ(
- TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
- .ok(),
- true);
+  // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64 for the
+  // output_shape tensor. MakeShape is able to handle both DT_INT32 and
+  // DT_INT64 for input_tensor.
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}
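
[editor note] The MakeShape call here presumably comes from the common backprop base class and dispatches on the dtype of the shape tensor. A hedged sketch of what such a dispatch could look like (wrapper name hypothetical; the TensorShapeUtils::MakeShape overloads match the int32 one already used in this file):

    // Hypothetical helper: build a TensorShape from an int32 or int64
    // shape tensor, as Conv[2D|3D]BackpropInputV2 allows either dtype.
    // (Inside TF; needs tensor.h and tensor_shape.h.)
    Status MakeShapeFromTensor(const Tensor& t, TensorShape* out) {
      if (t.dtype() == DT_INT32)
        return TensorShapeUtils::MakeShape(t.vec<int32>(), out);
      return TensorShapeUtils::MakeShape(t.vec<int64>(), out);
    }
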
@@ -792,7 +819,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
@@ -800,7 +827,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
@@ -839,17 +866,22 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
+#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif // INTEL_MKL