 tensorflow/core/graph/mkl_layout_pass.cc          |  28
 tensorflow/core/graph/mkl_tfconversion_pass.cc    |  12
 tensorflow/core/kernels/mkl_aggregate_ops.cc      |  20
 tensorflow/core/kernels/mkl_avgpooling_op.cc      |  51
 tensorflow/core/kernels/mkl_maxpooling_op.cc      |  59
 tensorflow/core/kernels/mkl_pooling_ops_common.cc | 129
 tensorflow/core/kernels/mkl_pooling_ops_common.h  | 132
 tensorflow/core/ops/nn_ops.cc                     |  98
 tensorflow/core/util/mkl_util.h                   |  16
 9 files changed, 430 insertions(+), 115 deletions(-)
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 833592caab..7e501c1717 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -334,6 +334,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
+
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
@@ -546,14 +547,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -2408,6 +2409,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.avg_pool3d = "AvgPool3D";
+ csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -2429,6 +2432,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.max_pool3d = "MaxPool3D";
+ csinfo_.max_pool3d_grad = "MaxPool3DGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
@@ -2463,6 +2468,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite});
@@ -2513,7 +2524,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, MaxpoolGradRewrite});
-
+ rinfo_.push_back({csinfo_.max_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
+ rinfo_.push_back({csinfo_.max_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
@@ -2550,6 +2566,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
+ wsinfo_.push_back(
+     {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
@@ -2617,6 +2635,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string add;
string avg_pool;
string avg_pool_grad;
+ string avg_pool3d;
+ string avg_pool3d_grad;
string bias_add;
string bias_add_grad;
string concat;
@@ -2637,6 +2657,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string max_pool3d;
+ string max_pool3d_grad;
string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
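The hunks above extend the layout pass's rewrite table so that AvgPool3D, MaxPool3D and their gradients get swapped for "_Mkl"-prefixed variants. Below is a minimal, standalone sketch of that bookkeeping, under assumptions: RewriteRule here is a pared-down stand-in for the pass's rinfo_ entries (the real entries also carry a CopyAttrs function, and MaxPool3D uses NonDepthBatchWisePoolRewrite rather than AlwaysRewrite), and GetMklOpName is modeled as simply prefixing "_Mkl".

#include <iostream>
#include <string>
#include <vector>

namespace {

// Simplified stand-in for mkl_op_registry::GetMklOpName: the MKL variant of a
// TF op is assumed to be the op name with an "_Mkl" prefix.
std::string GetMklOpName(const std::string& tf_op) { return "_Mkl" + tf_op; }

struct RewriteRule {
  std::string tf_op;         // op to look for in the graph, e.g. "MaxPool3D"
  std::string mkl_op;        // op to rewrite it to, e.g. "_MklMaxPool3D"
  bool (*should_rewrite)();  // per-node predicate (AlwaysRewrite stand-in here)
};

bool AlwaysRewrite() { return true; }

}  // namespace

int main() {
  std::vector<RewriteRule> rules;
  // Mirrors the four rinfo_.push_back calls added by this patch.
  for (const char* op :
       {"AvgPool3D", "AvgPool3DGrad", "MaxPool3D", "MaxPool3DGrad"}) {
    rules.push_back({op, GetMklOpName(op), AlwaysRewrite});
  }
  for (const auto& r : rules) {
    if (r.should_rewrite()) std::cout << r.tf_op << " -> " << r.mkl_op << "\n";
  }
  return 0;
}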
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index aa39af637f..b67a321fc1 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -175,7 +175,11 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
- if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
+ // using data_format. This code might be redundant.
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -254,9 +258,13 @@ Status MklToTfConversionPass::InsertInputConversionNode(
}
}
+ // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
+ // seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
- Status::OK()) {
+ Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 28edf51546..20aa1f7ea1 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -392,16 +392,28 @@ class MklAddNOp : public OpKernel {
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
auto src1_tf_data_format =
MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
- auto src2_dims =
- TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ memory::dims src2_dims;
+ if (src2_tensor.dims() == 4) {
+ src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ } else {
+ src2_dims = TFShapeToMklDnnDimsInNCDHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ }
md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
auto src2_tf_data_format =
MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
- auto src1_dims =
- TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ memory::dims src1_dims;
+ if (src1_tensor.dims() == 4) {
+ src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ } else {
+ src1_dims = TFShapeToMklDnnDimsInNCDHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ }
md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();
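The AddN change above branches on the tensor rank: 4-D inputs go through TFShapeToMklDnnDimsInNCHW, 5-D inputs through TFShapeToMklDnnDimsInNCDHW. As a hedged illustration of the 5-D case, the standalone snippet below reorders an NDHWC shape into MKL-DNN's canonical NCDHW order; the helper name NdhwcToNcdhw and the use of plain vectors instead of TensorShape/memory::dims are this sketch's own.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behaviour of TFShapeToMklDnnDimsInNCDHW for an NDHWC input:
// reorder {N, D, H, W, C} into MKL-DNN's canonical {N, C, D, H, W}.
std::vector<int64_t> NdhwcToNcdhw(const std::vector<int64_t>& shape) {
  assert(shape.size() == 5);
  return {shape[0], shape[4], shape[1], shape[2], shape[3]};
}

int main() {
  // A batch of 8 volumes, 16x32x32 voxels, 3 channels, in NDHWC order.
  std::vector<int64_t> ndhwc = {8, 16, 32, 32, 3};
  for (int64_t d : NdhwcToNcdhw(ndhwc)) std::cout << d << ' ';  // 8 3 16 32 32
  std::cout << '\n';
  return 0;
}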
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index 969baecc51..2409f7e9dc 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -453,6 +453,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -473,23 +475,22 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
}
memory::dims filter_dims, strides, padding_left, padding_right;
+ // Get src/filter/stride/padding information
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get the input memory descriptor
- memory::desc input_md =
- dnn_shape_input.IsMklTensor()
- ? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
-
- // Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
+ memory::desc input_md = dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(),
+ this->data_format_mkldnn_);
// Get an average pooling primitive from the op pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -562,24 +563,30 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
orig_input_shape.AddDim(shape_vec(i));
}
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
memory::dims diff_dst_dims =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -664,6 +671,18 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklAvgPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index e149f003e5..256d48f4d5 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -524,6 +524,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -547,20 +549,26 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
memory::desc input_md =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ : is_pool2d ? memory::desc(
+ TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_)
+ : memory::desc(
+ TFShapeToMklDnnDimsInNCDHW(
+ input_tensor_shape, this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_);
// Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
-
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get a pooling op from the cached pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -663,23 +671,30 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolParameters pool_params;
TensorShape orig_input_shape = orig_input_tensor.shape();
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
- memory::dims diff_dst_dims =
- grad_mkl_shape.IsMklTensor()
- ? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
+
+ memory::dims diff_dst_dims =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetSizesAsMklDnnDims()
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -715,7 +730,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
void* ws_data = static_cast<void*>(
const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
- ;
+
auto ws_md =
pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
@@ -817,6 +832,18 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklMaxPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index d7ad3f9dcd..ec6d241e17 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
@@ -46,9 +46,10 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
+ bool is_2d = (fwdParams.src_dims.size() == 4);
context_.src_md.reset(
new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
- get_desired_format(fwdParams.src_dims[1])));
+ get_desired_format(fwdParams.src_dims[1], is_2d)));
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
memory::format::any));
@@ -61,7 +62,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
// store expected primitive format
- context_.src_fmt = get_desired_format(fwdParams.src_dims[1]);
+ context_.src_fmt = get_desired_format(fwdParams.src_dims[1], is_2d);
context_.dst_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);
@@ -126,12 +127,14 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
}
context_.alg_kind = bwdParams.alg_kind;
+ // check whether it is 2d or 3d
+ bool is_2d = (bwdParams.dst_dims.size() == 4);
// Create memory desc
context_.diff_src_md.reset(new memory::desc(
{bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
- get_desired_format(bwdParams.dst_dims[1])));
+ get_desired_format(bwdParams.dst_dims[1], is_2d)));
context_.bwd_desc.reset(new pooling_backward::desc(
bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
@@ -151,7 +154,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
// store expected primitive format
context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
- context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]);
+ context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1], is_2d);
// create MKL-DNN internal memory object with dummy data
context_.diff_src_mem.reset(
@@ -165,7 +168,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
if (bwdParams.alg_kind == pooling_max) {
auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
- context_.ws_fmt = get_desired_format(context_.ws_dims[1]);
+ context_.ws_fmt = get_desired_format(context_.ws_dims[1], is_2d);
context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
context_.ws_mem.reset(new memory(
{{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
@@ -211,13 +214,22 @@ void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format,
const TensorShape& tensor_in_shape) {
- // For maxpooling, tensor_in should have 4 dimensions.
- OP_REQUIRES(context, tensor_in_shape.dims() == 4,
- errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ // For maxpooling, tensor_in should have 4 or 5 dimensions.
+ OP_REQUIRES(context,
+ tensor_in_shape.dims() == 4 || tensor_in_shape.dims() == 5,
+ errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
- tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
- tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ if (tensor_in_shape.dims() == 4) {
+ // Pool2D
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ } else {
+ // Pool3D
+ tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
+ }
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
Init(context, ksize, stride, padding, data_format);
@@ -246,10 +258,20 @@ void MklPoolParameters::Init(OpKernelContext* context,
TensorFormat data_format,
const MklDnnShape* mklInputShape) {
// Get the input sizes
- depth = mklInputShape->GetDimension('C');
- tensor_in_cols = mklInputShape->GetDimension('W');
- tensor_in_rows = mklInputShape->GetDimension('H');
- tensor_in_batch = mklInputShape->GetDimension('N');
+ if (ksize.size() == 4) {
+ // Pool2D
+ depth = mklInputShape->GetDimension('C');
+ tensor_in_cols = mklInputShape->GetDimension('W');
+ tensor_in_rows = mklInputShape->GetDimension('H');
+ tensor_in_batch = mklInputShape->GetDimension('N');
+ } else {
+ // Pool3D
+ depth = mklInputShape->GetDimension3D('C');
+ tensor_in_cols = mklInputShape->GetDimension3D('W');
+ tensor_in_rows = mklInputShape->GetDimension3D('H');
+ tensor_in_planes = mklInputShape->GetDimension3D('D');
+ tensor_in_batch = mklInputShape->GetDimension3D('N');
+ }
Init(context, ksize, stride, padding, data_format);
}
@@ -262,25 +284,58 @@ void MklPoolParameters::Init(OpKernelContext* context,
// Get the data format
this->data_format = data_format;
- // Get the output sizes
- window_rows = GetTensorDim(ksize, data_format, 'H');
- window_cols = GetTensorDim(ksize, data_format, 'W');
- depth_window = GetTensorDim(ksize, data_format, 'C');
-
- // Get the strides
- row_stride = GetTensorDim(stride, data_format, 'H');
- col_stride = GetTensorDim(stride, data_format, 'W');
- depth_stride = GetTensorDim(stride, data_format, 'C');
+ bool is_pool2d = (ksize.size() == 4);
+ if (is_pool2d) {
+ // Pool2D
+ // Get the output sizes
+ window_rows = GetTensorDim(ksize, data_format, 'H');
+ window_cols = GetTensorDim(ksize, data_format, 'W');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ row_stride = GetTensorDim(stride, data_format, 'H');
+ col_stride = GetTensorDim(stride, data_format, 'W');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 2D pooling across width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
+ errors::Unimplemented(
+ "MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height."));
+ } else {
+ // Pool3D
+ // Get the output sizes
+ window_planes = GetTensorDim(ksize, data_format, '0');
+ window_rows = GetTensorDim(ksize, data_format, '1');
+ window_cols = GetTensorDim(ksize, data_format, '2');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ planes_stride = GetTensorDim(stride, data_format, '0');
+ row_stride = GetTensorDim(stride, data_format, '1');
+ col_stride = GetTensorDim(stride, data_format, '2');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 3D pooling across depth/width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 ||
+ (window_rows == 1 && window_cols == 1 && window_planes == 1)),
+ errors::Unimplemented(
+ "AvgPooling3D supports exactly one of pooling across depth "
+ "or pooling across depth/width/height."));
+ }
- // We only support 2D pooling across width/height and depthwise
- // pooling, not a combination.
- OP_REQUIRES(context,
- (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
- errors::Unimplemented(
- "MaxPooling supports exactly one of pooling across depth "
- "or pooling across width/height."));
+ if (depth_window == 1) { // we are pooling in the D (Pool3D only), H and W
+ if (!is_pool2d) {
+ OP_REQUIRES_OK(
+ context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
+ planes_stride, padding,
+ &out_planes, &pad_P1, &pad_P2));
+ }
- if (depth_window == 1) { // we are pooling in the H and W
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_rows, window_rows, row_stride,
padding, &out_height, &pad_top, &pad_bottom));
@@ -290,7 +345,14 @@ void MklPoolParameters::Init(OpKernelContext* context,
padding, &out_width, &pad_left, &pad_right));
#ifndef INTEL_MKL_ML_ONLY
// TF can work with int64, but mkldnn only supports int32
- // Fail if the height or width are greater than MAX_INT
+ // Fail if the depth, height or width are greater than MAX_INT
+ // We check depth only for 3D pooling case
+
+ if (!is_pool2d) {
+ OP_REQUIRES(context,
+ FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("output depth/planes is too large"));
+ }
OP_REQUIRES(context,
FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
@@ -299,7 +361,6 @@ void MklPoolParameters::Init(OpKernelContext* context,
OP_REQUIRES(context,
FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
errors::InvalidArgument("output width is too large"));
-
#endif
out_depth = depth; // output will have the same depth as the input
} else { // we are pooling in the depth dimension
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index ec7af5092d..49f799d7ba 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -19,6 +19,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include <memory>
#include <vector>
+#include <string>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ using mkldnn::stream;
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::memory;
using mkldnn::pooling_avg;
@@ -357,22 +358,28 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
int depth;
+ int tensor_in_planes; // Pool3D
int tensor_in_cols;
int tensor_in_rows;
int tensor_in_batch;
+ int window_planes; // Pool3D
int window_rows;
int window_cols;
int depth_window;
+ int planes_stride; // Pool3D
int row_stride;
int col_stride;
int depth_stride;
+ int64 out_planes; // Pool3D
int64 out_height;
int64 out_width;
int out_depth;
+ int64 pad_P1; // Pool3D
+ int64 pad_P2; // Pool3D
int64 pad_left;
int64 pad_right;
int64 pad_top;
@@ -382,18 +389,24 @@ struct MklPoolParameters {
TensorFormat data_format;
MklPoolParameters()
: depth(0),
+ tensor_in_planes(0),
tensor_in_cols(0),
tensor_in_rows(0),
tensor_in_batch(0),
+ window_planes(0),
window_rows(0),
window_cols(0),
depth_window(0),
+ planes_stride(0),
row_stride(0),
col_stride(0),
depth_stride(0),
+ out_planes(0),
out_height(0),
out_width(0),
out_depth(0),
+ pad_P1(0),
+ pad_P2(0),
pad_left(0),
pad_right(0),
pad_top(0),
@@ -433,20 +446,22 @@ class MklPoolingOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
errors::InvalidArgument("Invalid data format"));
- this->data_format_mkldnn_ =
- TFDataFormatToMklDnnDataFormat(this->data_format_tf_);
OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
- OP_REQUIRES(context, this->ksize_.size() == 4,
+ OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
- OP_REQUIRES(context, this->stride_.size() == 4,
+ OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
errors::Unimplemented("Pooling is not yet supported on the "
"batch dimension."));
+ bool is_pool2d = (this->ksize_.size() == 4);
+ this->data_format_mkldnn_ =
+ is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);
// We may not get this attribute for this node if it does not go through
// graph rewrite pass. So we do not check for error while retrieving this
@@ -457,17 +472,26 @@ class MklPoolingOpBase : public OpKernel {
protected:
// Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function expects
- // output height and output width to have already been int32
- // bounds-checked
+ // MKL-DNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order.
+ // But TensorFlow output will be in NHWC/NCHW(Pool2D) or
+ // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects
+ // output height and width to have already been int32 bounds-checked.
void GetOutputDims(const MklPoolParameters& mkl_pool_params,
memory::dims* output_dims_mkl_order) {
- // MKL-DNN always needs output in NCHW format.
- *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
- mkl_pool_params.out_depth,
- static_cast<int>(mkl_pool_params.out_height),
- static_cast<int>(mkl_pool_params.out_width)};
+ if (this->ksize_.size() == 4) {
+ // Pooling2D: MKL-DNN always needs output in NCHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ } else {
+ // Pooling3D: MKL-DNN always needs output in NCDHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_planes),
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ }
}
void InitMklPoolParameters(OpKernelContext* context,
@@ -485,14 +509,34 @@ class MklPoolingOpBase : public OpKernel {
void PoolParamsToDims(const MklPoolParameters* pool_params,
memory::dims* filter_dims, memory::dims* strides,
- memory::dims* padding_left,
- memory::dims* padding_right) {
- *filter_dims = {pool_params->window_rows, pool_params->window_cols};
- *strides = {pool_params->row_stride, pool_params->col_stride};
- *padding_left = {static_cast<int>(pool_params->pad_top),
- static_cast<int>(pool_params->pad_left)};
- *padding_right = {static_cast<int>(pool_params->pad_bottom),
- static_cast<int>(pool_params->pad_right)};
+ memory::dims* padding_left, memory::dims* padding_right,
+ bool is_pool2d) {
+ if (is_pool2d) {
+ // Pool2D
+ *filter_dims =
+ memory::dims({pool_params->window_rows, pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->row_stride, pool_params->col_stride});
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ } else {
+ // Pool3D
+ *filter_dims =
+ memory::dims({pool_params->window_planes, pool_params->window_rows,
+ pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->planes_stride, pool_params->row_stride,
+ pool_params->col_stride});
+
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
+ static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
+ static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ }
}
void AllocateEmptyOutputTensor(OpKernelContext* context,
@@ -556,12 +600,27 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
TensorShape input_tensor_shape = input_tensor.shape();
if (input_tensor.NumElements() != 0) {
memory::desc input_md =
- input_mkl_shape.IsMklTensor()
- ? input_mkl_shape.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetMklLayout()
+ : memory::desc(
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ MklDnnType<T>(), this->data_format_mkldnn_);
dnn_data_input->SetUsrMem(input_md, &input_tensor);
+
+ if (this->ksize_.size() == 5) {
+ // Pool3D
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
+ dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
+ }
}
this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
input_tensor_shape);
@@ -593,12 +652,13 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
const MklDnnShape& input_mkl_shape) {
if (!input_mkl_shape.IsMklTensor()) {
- OP_REQUIRES(context, input_tensor.dims() == 4,
- errors::InvalidArgument("Input must be 4-dimensional"));
+ OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
+ errors::InvalidArgument("Input must be 4 or 5-dimensional"));
} else {
- OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4,
+ OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4 ||
+ input_mkl_shape.GetDimension() == 5,
errors::InvalidArgument("Input shape must be "
- "4-dimensional"));
+ "4 or 5-dimensional"));
}
}
// .Input("value: T")
@@ -649,8 +709,12 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
input_gradient_mkl_shape.IsMklTensor()
? input_gradient_mkl_shape.GetMklLayout()
: memory::desc(
- TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
- this->data_format_tf_),
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(
+ input_gradient_tensor.shape(),
+ this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
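For the 3D path, PoolParamsToDims above packs three-element filter/stride vectors ordered planes, rows, cols, and padding vectors ordered P1/top/left and P2/bottom/right, while GetOutputDims emits NCDHW output dims. The snippet below is only a shape-of-the-data sketch under those assumptions, with a hypothetical Pool3dParams struct and plain std::vector<int> standing in for MklPoolParameters and memory::dims.

#include <iostream>
#include <vector>

// Hypothetical pared-down pool parameters; field names mirror MklPoolParameters.
struct Pool3dParams {
  int tensor_in_batch, out_depth;
  int out_planes, out_height, out_width;
  int window_planes, window_rows, window_cols;
  int planes_stride, row_stride, col_stride;
  int pad_P1, pad_top, pad_left, pad_P2, pad_bottom, pad_right;
};

int main() {
  Pool3dParams p{8, 3, 8, 16, 16, 3, 3, 3, 2, 2, 2, 0, 1, 1, 1, 1, 1};

  // Mirrors the Pool3D branch of PoolParamsToDims: planes first, then rows, cols.
  std::vector<int> filter = {p.window_planes, p.window_rows, p.window_cols};
  std::vector<int> strides = {p.planes_stride, p.row_stride, p.col_stride};
  std::vector<int> pad_l = {p.pad_P1, p.pad_top, p.pad_left};
  std::vector<int> pad_r = {p.pad_P2, p.pad_bottom, p.pad_right};

  // Mirrors the Pool3D branch of GetOutputDims: MKL-DNN wants NCDHW order.
  std::vector<int> out_ncdhw = {p.tensor_in_batch, p.out_depth, p.out_planes,
                                p.out_height, p.out_width};
  for (int d : out_ncdhw) std::cout << d << ' ';  // 8 3 8 16 16
  std::cout << '\n';
  return 0;
}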
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 94476acd4b..658e116ac8 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2026,6 +2026,104 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklAvgPool3D")
+ .Input("value: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
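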
+
+REGISTER_OP("_MklAvgPool3DGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_grad: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of AvgPool3D function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3D")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("workspace: uint8")
+ .Output("mkl_output: uint8")
+ .Output("mkl_workspace: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float}")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3DGrad")
+ .Input("orig_input: TInput")
+ .Input("orig_output: TInput")
+ .Input("grad: T")
+ .Input("workspace: uint8")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_orig_output: uint8")
+ .Input("mkl_grad: uint8")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 5);
+ })
+ .Doc(R"doc(
+MKL version of MaxPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of MaxPool3D function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklLRN")
.Input("input: T")
.Input("mkl_input: uint8")
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 422be9356d..0a96a603d0 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -66,7 +66,6 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-
namespace tensorflow {
// The file contains a number of utility classes and functions used by MKL
@@ -645,6 +644,7 @@ class MklDnnShape {
}
}
+
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
SetTfDimOrder(dimension, data_format);
@@ -2059,16 +2059,20 @@ class FactoryKeyCreator {
}
};
-static inline memory::format get_desired_format(int channel) {
+
+static inline memory::format get_desired_format(int channel,
+ bool is_2d = true) {
memory::format fmt_desired = memory::format::any;
- if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) {
- fmt_desired = memory::format::nChw16c;
+ if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
+ fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = memory::format::nChw8c;
+ fmt_desired = is_2d
+ ? memory::format::nChw8c
+ : memory::format::ncdhw;  // no AVX2 support for 3D yet.
} else {
- fmt_desired = memory::format::nchw;
+ fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
return fmt_desired;
}
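
The patched get_desired_format now picks between 2D and 3D layouts based on CPU features, falling back to plain ncdhw whenever a blocked 3D format is unavailable. The standalone sketch below reproduces that branching with the port::TestCPUFeature probes replaced by explicit booleans so it compiles without TF; the Fmt enum merely names the mkldnn::memory::format values used above.

#include <iostream>

// Stand-ins for the mkldnn::memory::format values referenced in the patch.
enum class Fmt { nchw, nChw8c, nChw16c, ncdhw, nCdhw16c };

const char* Name(Fmt f) {
  switch (f) {
    case Fmt::nchw: return "nchw";
    case Fmt::nChw8c: return "nChw8c";
    case Fmt::nChw16c: return "nChw16c";
    case Fmt::ncdhw: return "ncdhw";
    case Fmt::nCdhw16c: return "nCdhw16c";
  }
  return "?";
}

// Same branching as the patched get_desired_format, with the CPU-feature
// probes passed in as booleans.
Fmt DesiredFormat(int channel, bool is_2d, bool has_avx512f, bool has_avx2) {
  if (has_avx512f) {
    return is_2d ? Fmt::nChw16c : Fmt::nCdhw16c;
  } else if (has_avx2 && (channel % 8) == 0) {
    return is_2d ? Fmt::nChw8c : Fmt::ncdhw;  // no AVX2 blocked layout for 3D yet
  }
  return is_2d ? Fmt::nchw : Fmt::ncdhw;
}

int main() {
  // AVX2-only machine, 64 channels: blocked layout for 2D, plain ncdhw for 3D.
  std::cout << Name(DesiredFormat(64, /*is_2d=*/true, false, true)) << '\n';   // nChw8c
  std::cout << Name(DesiredFormat(64, /*is_2d=*/false, false, true)) << '\n';  // ncdhw
  return 0;
}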