about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/mkl_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 146
1 file changed, 95 insertions(+), 51 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 1440da8f82..f0818eb96d 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -493,6 +493,7 @@ class MklConv2DOp : public OpKernel {
~MklConv2DOp() {}
explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -509,6 +510,20 @@ class MklConv2DOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
}
void Compute(OpKernelContext* context) override {
@@ -530,17 +545,19 @@ class MklConv2DOp : public OpKernel {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> output(&cpu_engine);
- memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
+ memory::dims src_dims, filter_dims, padding_l, padding_r,
+ dilations, strides;
memory::dims output_dims_tf_order, output_dims_mkl_order;
// Get shapes of input tensors in MKL-DNN order
- MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
+ dilations_);
auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
conv_utl.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
- &output_dims_tf_order, &output_dims_mkl_order, &padding_l,
- &padding_r);
+ &dilations, &output_dims_tf_order, &output_dims_mkl_order,
+ &padding_l, &padding_r);
if (!context->status().ok()) return;
// Check for corner case - if there is nothing to compute, return.
@@ -553,6 +570,7 @@ class MklConv2DOp : public OpKernel {
// Need semantics for Null MKL tensor
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
+
AllocateOutputSetMklShape(context, kOutputIndex_Dst, &output_tensor,
src_tf_shape, output_mkl_shape);
@@ -596,55 +614,79 @@ class MklConv2DOp : public OpKernel {
filter.SetOpMemDesc(filter_dims, memory::format::any);
output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
- // If bias is enabled, then do the same steps as above for bias.
+ // MKLDNN dilation starts from 0.
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+
if (biasEnabled) {
- MklDnnData<T> bias(&cpu_engine);
- memory::dims bias_size;
- conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_size);
- const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
- bias.SetUsrMem(bias_size, memory::format::x, &bias_tensor);
- bias.SetOpMemDesc(bias_size, memory::format::any);
-
- // Create convolution primitive with Bias.
- auto conv_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
- filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(),
- strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
-
- auto conv_prim_desc =
- convolution_forward::primitive_desc(conv_desc, cpu_engine);
- AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order,
- tf_fmt, &output_tensor);
- // Set data handle for output.
- output.SetUsrMemDataHandle(output_tensor);
-
- Tensor* filter_out_tensor = nullptr;
- AllocateFilterOutputTensor(context, conv_prim_desc,
- TFShapeToMklDnnDims(filter_tf_shape),
- &filter_out_tensor);
-
- PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output,
- filter_out_tensor);
+ // Create convolution primitive with Bias.
+ MklDnnData<T> bias(&cpu_engine);
+ memory::dims bias_size;
+ conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_size);
+ const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
+ bias.SetUsrMem(bias_size, memory::format::x, &bias_tensor);
+ bias.SetOpMemDesc(bias_size, memory::format::any);
+
+ // Create convolution primitive with Bias.
+ // Use MKLDNN dilated convolution in case of dilated rate (>0).
+ auto conv_desc = (dilations[kDilationH] > 0 ||
+ dilations[kDilationW] > 0) ?
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), bias.GetOpMemDesc(),
+ output.GetOpMemDesc(), strides, dilations,
+ padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_)):
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), bias.GetOpMemDesc(),
+ output.GetOpMemDesc(), strides,
+ padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
+ cpu_engine);
+ AllocateOutputTensor(context, conv_prim_desc,
+ output_dims_mkl_order, tf_fmt, &output_tensor);
+ // Set data handle for output.
+ output.SetUsrMemDataHandle(output_tensor);
+
+ Tensor* filter_out_tensor = nullptr;
+ AllocateFilterOutputTensor(context, conv_prim_desc,
+ TFShapeToMklDnnDims(filter_tf_shape),
+ &filter_out_tensor);
+
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output,
+ filter_out_tensor);
} else {
- // Create convolution primitive without Bias.
- auto conv_desc = convolution_forward::desc(
- prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
- filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_));
-
- auto conv_prim_desc =
- convolution_forward::primitive_desc(conv_desc, cpu_engine);
- AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order,
- tf_fmt, &output_tensor);
- // Set data handle for output.
- output.SetUsrMemDataHandle(output_tensor);
-
- Tensor* filter_out_tensor = nullptr;
- AllocateFilterOutputTensor(context, conv_prim_desc,
- TFShapeToMklDnnDims(filter_tf_shape),
- &filter_out_tensor);
- PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output,
- filter_out_tensor);
+ // Create convolution primitive without Bias.
+ // Use MKLDNN dilated convolution in case of dilated rate (>0).
+ auto conv_desc = (dilations[kDilationH] > 0 ||
+ dilations[kDilationW] > 0) ?
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), output.GetOpMemDesc(),
+ strides, dilations, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_)):
+ convolution_forward::desc(prop_kind::forward,
+ convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), output.GetOpMemDesc(),
+ strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
+ cpu_engine);
+ AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order,
+ tf_fmt, &output_tensor);
+ // Set data handle for output.
+ output.SetUsrMemDataHandle(output_tensor);
+
+ Tensor* filter_out_tensor = nullptr;
+ AllocateFilterOutputTensor(context, conv_prim_desc,
+ TFShapeToMklDnnDims(filter_tf_shape),
+ &filter_out_tensor);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter,
+ nullptr, &output, filter_out_tensor);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@@ -658,10 +700,12 @@ class MklConv2DOp : public OpKernel {
private:
std::vector<int32> strides_;
+ std::vector<int32> dilations_;
Padding padding_;
TensorFormat data_format_;
const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
+ const int kDilationH = 0, kDilationW = 1;
// Allocate output tensor.
void AllocateOutputTensor(