aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 13:04:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 13:05:05 -0700
commit 9c50882415cb87a7eb81048d42401c64bf0617ef (patch)
tree c550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core
parent 19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff)
parent 62191da0819b25906c1b2ed96159cfe36ba00383 (diff)
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/graph/mkl_layout_pass.cc 39
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc 174
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_input_ops.cc 144
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.cc 157
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.h 414
-rw-r--r-- tensorflow/core/ops/nn_ops.cc 85
-rw-r--r-- tensorflow/core/util/mkl_util.h 103
-rw-r--r-- tensorflow/core/util/tensor_format.cc 4
-rw-r--r-- tensorflow/core/util/tensor_format.h 1
9 files changed, 756 insertions, 365 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 5683944e46..833592caab 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2418,6 +2418,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.conv2d_grad_filter_with_bias =
"__MklDummyConv2DBackpropFilterWithBias";
+ csinfo_.conv3d = "Conv3D";
+ csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
+ csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.identity = "Identity";
@@ -2468,18 +2471,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
- csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+ csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_filter,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_input,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
@@ -2614,6 +2626,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv2d_grad_input;
string conv2d_grad_filter;
string conv2d_grad_filter_with_bias;
+ string conv3d;
+ string conv3d_grad_input;
+ string conv3d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
string identity;
@@ -3086,7 +3101,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
@@ -3571,14 +3586,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
std::vector<int32> strides;
std::vector<int32> dilations;
- bool use_cudnn_on_gpu;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
@@ -3586,8 +3600,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@@ -3595,7 +3607,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("dilations", dilations);
nb->Attr("padding", padding);
nb->Attr("data_format", data_format);
- nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
@@ -3896,7 +3907,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
// Copy attributes from Conv2D to Conv2DWithBias.
- CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
+ CopyAttrsConv(const_cast<const Node*>(pred), &nb);
// Copy the device assigned to old node to new node.
nb.Device(succ->def().device());
@@ -4007,7 +4018,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
}
// Copy attributes from Conv2DBackpropFilter.
- CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb);
+ CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
// Copy the device assigned to old node to new node.
nb.Device(fltr->def().device());
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 50c25e1da7..afbfaa83f3 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -82,11 +82,11 @@ struct MklConvBwdFilterParams {
};
template <typename T>
-class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdFilterPrimitive(
- const MklConvBwdFilterParams& convBwdFilterDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_bwd_filter == nullptr) {
@@ -94,7 +94,7 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
}
}
- ~MklConv2DBwdFilterPrimitive() {}
+ ~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
@@ -297,38 +297,36 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DBwdFilterPrimitive<T>* Get(
+ static MklConvBwdFilterPrimitive<T>* Get(
const MklConvBwdFilterParams& convBwdFilterDims) {
- MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
// look into the pool for reusable primitive
- conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
- convBwdFilterDims));
-
- if (conv2d_bwd_filter == nullptr) {
- conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
- convBwdFilterDims);
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
- convBwdFilterDims, conv2d_bwd_filter);
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
+ convBwdFilterDims));
+
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
}
- return conv2d_bwd_filter;
+ return conv_bwd_filter;
}
-
private:
- MklConv2DBwdFilterPrimitiveFactory() {}
- ~MklConv2DBwdFilterPrimitiveFactory() {}
+ MklConvBwdFilterPrimitiveFactory() {}
+ ~MklConvBwdFilterPrimitiveFactory() {}
- static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
- static MklConv2DBwdFilterPrimitiveFactory instance_;
+ static MklConvBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConvBwdFilterPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
- string prefix = "conv2d_bwd_filter";
+ string prefix = "conv_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -342,14 +340,14 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdFilter(
+ MklPrimitive* GetConvBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
- void SetConv2dBwdFilter(
- const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
@@ -738,14 +736,13 @@ TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#else
template <typename Device, class T, bool biasEnabled>
-class MklConv2DCustomBackpropFilterOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropFilterOp
+ : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropFilterOp() {}
+ ~MklConvCustomBackpropFilterOp() {}
void Compute(OpKernelContext* context) {
try {
@@ -753,6 +750,9 @@ class MklConv2DCustomBackpropFilterOp
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
+ // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -813,7 +813,10 @@ class MklConv2DCustomBackpropFilterOp
&fwd_dst_dims, &padding_left, &padding_right);
if (!context->status().ok()) return;
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
+
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
@@ -832,21 +835,19 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
TensorShape obp_tf_shape = GetTfShape(context, 2);
depth = (this->data_format_ == FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(isConv2D ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
- convBwdFilterDims);
- auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+ conv_bwd_filter =
+ MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims);
+ auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_filter and diff_bias (w bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
@@ -854,14 +855,26 @@ class MklConv2DCustomBackpropFilterOp
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
- // output_dims_mkl_order is in OIHW format.
- TensorShape diff_filter_tf_shape(
- {bwd_output_dims[MklDnnDims::Dim_H],
- bwd_output_dims[MklDnnDims::Dim_W],
- bwd_output_dims[MklDnnDims::Dim_I],
- bwd_output_dims[MklDnnDims::Dim_O]});
- AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
- diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ if (isConv2D) {
+ // Conv2D: output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape({bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ } else {
+ // Conv3D: output_dims_mkl_order is in OIDHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims3D::Dim3d_D],
+ bwd_output_dims[MklDnnDims3D::Dim3d_H],
+ bwd_output_dims[MklDnnDims3D::Dim3d_W],
+ bwd_output_dims[MklDnnDims3D::Dim3d_I],
+ bwd_output_dims[MklDnnDims3D::Dim3d_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ }
Tensor* diff_bias_tensor = nullptr;
if (biasEnabled) {
@@ -871,7 +884,7 @@ class MklConv2DCustomBackpropFilterOp
// check if src and diff_dst need reorder
T *src_data = nullptr;
- if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -882,7 +895,7 @@ class MklConv2DCustomBackpropFilterOp
T *diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
- conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ conv_bwd_filter->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -897,7 +910,7 @@ class MklConv2DCustomBackpropFilterOp
bool diff_filter_reorder_required = false;
T *diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
- conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
diff_filter_tensor);
@@ -915,10 +928,10 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
T* diff_bias_data = static_cast<T*>(const_cast<T*>(
diff_bias_tensor->flat<T>().data()));
- conv2d_bwd_filter->Execute(src_data, diff_filter_data,
- diff_bias_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
+ diff_dst_data);
} else {
- conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
@@ -947,7 +960,7 @@ class MklConv2DCustomBackpropFilterOp
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
CHECK(!filter_mkl_shape.IsMklTensor())
- << "Conv2DBackpropFilter: filter should not be in MKL Layout";
+ << "ConvBackpropFilter: filter should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -983,9 +996,11 @@ class MklConv2DCustomBackpropFilterOp
return fwd_filter_dims;
}
- // Output layout is Tensorflow's filter layout (HWIO).
+ // Output layout is Tensorflow's filter layout
+ // Conv2D: HWIO; Conv3D: DHWIO
memory::format GetOutputFormat(const memory::format data_format) {
- return memory::format::hwio;
+ return (this->strides_.size() == 4) ? memory::format::hwio
+ : memory::format::dhwio;
}
// Allocate output tensor.
@@ -1027,24 +1042,27 @@ class MklConv2DCustomBackpropFilterOp
}
};
-#define REGISTER_MKL_FILTER_KERNELS(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilter") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \
- REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklDummyOp<CPUDevice, T>);
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, true>); \
+ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDummyOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..b5a98301e2 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_ML_ONLY
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
};
template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdInputPrimitive(
- const MklConvBwdInputParams& convBwdInputDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
Setup(convBwdInputDims);
}
}
- ~MklConv2DBwdInputPrimitive() {}
+ ~MklConvBwdInputPrimitive() {}
// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
private:
- // Primitive reuse context for Conv2D Bwd Input op
+ // Primitive reuse context for Conv Bwd Input op
struct ConvBwdInputContext {
// expected memory format for this primitive instance
memory::format filter_fmt;
@@ -235,38 +235,37 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
- MklConv2DBwdInputPrimitiveFactory() {}
- ~MklConv2DBwdInputPrimitiveFactory() {}
+ MklConvBwdInputPrimitiveFactory() {}
+ ~MklConvBwdInputPrimitiveFactory() {}
public:
- static MklConv2DBwdInputPrimitive<T>* Get(
+ static MklConvBwdInputPrimitive<T>* Get(
const MklConvBwdInputParams& convBwdInputDims) {
- MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
// look into the pool for reusable primitive
- conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
convBwdInputDims));
- if (conv2d_bwd_input == nullptr) {
- conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
- convBwdInputDims);
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
- convBwdInputDims, conv2d_bwd_input);
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
}
- return conv2d_bwd_input;
+ return conv_bwd_input;
}
private:
- static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
- static MklConv2DBwdInputPrimitiveFactory instance_;
+ static MklConvBwdInputPrimitiveFactory& GetInstance() {
+ static MklConvBwdInputPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
- string prefix = "conv2d_bwd_input";
+ string prefix = "conv_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -279,14 +278,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims) {
+ MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
- void SetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
@@ -594,23 +592,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
#else
template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropInputOp() {}
+ ~MklConvCustomBackpropInputOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
+ // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -626,7 +635,7 @@ class MklConv2DCustomBackpropInputOp
diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
// tensor containing shape of filter. So filter.shape() is not
// a correct way to get filter shape. These operator-specific calls
// allow this class to handle this case.
@@ -655,6 +664,7 @@ class MklConv2DCustomBackpropInputOp
}
return;
}
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +683,18 @@ class MklConv2DCustomBackpropInputOp
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO (Conv2D) or DHWIO (Conv3D) format.
auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
@@ -689,18 +702,15 @@ class MklConv2DCustomBackpropInputOp
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims,
MklDnnType<T>(), tf_fmt);
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
- conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
- convBwdInputDims);
- auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+ conv_bwd_input =
+ MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims);
+ auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
@@ -723,7 +733,7 @@ class MklConv2DCustomBackpropInputOp
// check if filter and diff_dst need reorder
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
- conv2d_bwd_input->GetFilterMemoryFormat()) {
+ conv_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +743,7 @@ class MklConv2DCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
- if (diff_dst_md.data.format !=
- conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -745,7 +754,7 @@ class MklConv2DCustomBackpropInputOp
}
// execute convolution input bwd
- conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -770,7 +779,7 @@ class MklConv2DCustomBackpropInputOp
// of the Tensor and never an actual tensor. So it will never be in MKL
// layout.
CHECK(!input_mkl_shape.IsMklTensor())
- << "Conv2DBackpropInput: input should not be in MKL Layout";
+ << "ConvBackpropInput: input should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -778,10 +787,10 @@ class MklConv2DCustomBackpropInputOp
const Tensor& input_tensor) {
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
- CHECK_EQ(
- TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
- .ok(),
- true);
+ // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64
+ // output_shape. MakeShape is able to handle both DT_INT32 and DT_INT64 for
+ // input_tensor.
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}
@@ -792,7 +801,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
@@ -800,7 +809,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
@@ -839,17 +848,22 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
+#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index bca1aa21a8..c6295c7280 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -85,9 +85,9 @@ struct MklConvFwdParams {
};
template <typename T>
-class MklConv2DFwdPrimitive : public MklPrimitive {
+class MklConvFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -96,7 +96,7 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
}
}
- ~MklConv2DFwdPrimitive() {}
+ ~MklConvFwdPrimitive() {}
// Convolution forward execute with bias
// src_data: input data buffer of src
@@ -269,37 +269,36 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
// try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
- conv2d_fwd);
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
}
- return conv2d_fwd;
+ return conv_fwd;
}
private:
- MklConv2DFwdPrimitiveFactory() {}
- ~MklConv2DFwdPrimitiveFactory() {}
+ MklConvFwdPrimitiveFactory() {}
+ ~MklConvFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1;
- static MklConv2DFwdPrimitiveFactory& GetInstance() {
- static MklConv2DFwdPrimitiveFactory instance_;
+ static MklConvFwdPrimitiveFactory& GetInstance() {
+ static MklConvFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvFwdParams& convFwdDims) {
- string prefix = "conv2d_fwd_";
+ string prefix = "conv_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -313,12 +312,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
+ MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
+ void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -331,11 +330,11 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// For now, MKL-ML is default. So making MKL-DNN not a default choice.
#ifdef INTEL_MKL_ML_ONLY
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -755,21 +754,22 @@ class MklConv2DOp : public OpKernel {
#else
+// Base class for convolution forward operations
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES(context, strides_.size() == 4,
+ OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
@@ -778,20 +778,39 @@ class MklConv2DOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ } else if (strides_.size() == 5) {
+ OP_REQUIRES(context, dilations_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
+ GetTensorDim(dilations_, data_format_, 'C') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations rates in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(dilations_, data_format_, '0') > 0 &&
+ GetTensorDim(dilations_, data_format_, '1') > 0 &&
+ GetTensorDim(dilations_, data_format_, '2') > 0),
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -837,7 +856,8 @@ class MklConv2DOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst,
&dst_tensor, src_tf_shape, dst_mkl_shape);
- // MklConv2D also outputs converted filter as 2nd output of Conv2D.
+ // MklConv2D/3D also outputs converted filter
+ // as 2nd output of Conv2D/3D.
filter_mkl_shape.SetMklTensor(false);
Tensor* output_filter_tensor = nullptr;
AllocateOutputSetMklShape(context, kOutputIndex_Filter,
@@ -846,15 +866,20 @@ class MklConv2DOp : public OpKernel {
return;
}
+ bool isConv2D = (strides_.size() == 4);
+
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto tf_fmt = isConv2D ? TFDataFormatToMklDnnDataFormat(data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(data_format_);
// If input is in MKL layout, then simply grab input layout; otherwise,
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
- // layout (NHWC or NCHW depending on data format).
+ // layout depending on data format:
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
@@ -864,31 +889,30 @@ class MklConv2DOp : public OpKernel {
auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(),
- memory::format::hwio);
-
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
// MKLDNN dilation starts from 0.
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
}
// allocate output tensors output_tensor and filter_out_tensor
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
- conv2d_fwd->GetPrimitiveDesc();
+ conv_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -900,7 +924,7 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder
T *src_data = nullptr;
- if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
+ if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -908,7 +932,7 @@ class MklConv2DOp : public OpKernel {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* filter_data = nullptr;
- if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
+ if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor));
@@ -918,16 +942,15 @@ class MklConv2DOp : public OpKernel {
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
-
// execute convolution
if (biasEnabled) {
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
T* bias_data = static_cast<T*>(const_cast<T*>(
bias_tensor.flat<T>().data()));
- conv2d_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
} else {
- conv2d_fwd->Execute(src_data, filter_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, dst_data);
}
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
@@ -1038,17 +1061,18 @@ class MklConv2DOp : public OpKernel {
#endif
+// Register 2D operations
#define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, false>); \
+ MklConvOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, true>); \
+ MklConvOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
@@ -1057,5 +1081,14 @@ class MklConv2DOp : public OpKernel {
TF_CALL_float(REGISTER_MKL_CPU);
+// Register 3D operations
+#define REGISTER_MKL_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvOp<CPUDevice, T, false>);
+TF_CALL_float(REGISTER_MKL_CPU);
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 838c06f49d..01cc606f41 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -79,9 +79,16 @@ class MklDnnConvUtil {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
- int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- *strides = {stride_rows, stride_cols};
+ if (strides_.size() == 4) {
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ } else if (strides_.size() == 5) {
+ int stride_planes = GetTensorDim(strides_, data_format_, '0');
+ int stride_rows = GetTensorDim(strides_, data_format_, '1');
+ int stride_cols = GetTensorDim(strides_, data_format_, '2');
+ *strides = {stride_planes, stride_rows, stride_cols};
+ }
}
// Calculate Convolution dilations
@@ -89,13 +96,20 @@ class MklDnnConvUtil {
// For now we take the dilation from the second and third dimensions only
// (we do not support dilation on the batch or depth dimension).
CHECK_NOTNULL(dilations);
- int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
- int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
- *dilations = {dilations_rows, dilations_cols};
+ if (dilations_.size() == 4) {
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ } else if (dilations_.size() == 5) {
+ int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
+ int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
+ *dilations = {dilations_planes, dilations_rows, dilations_cols};
+ }
}
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
- // requires input in NCHW format. Function does not return anything.
+ // requires input in NCHW/NCDHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
@@ -113,40 +127,62 @@ class MklDnnConvUtil {
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
- // Input rows/height
- int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
- CHECK_BOUNDS(input_rows_raw, "Input rows too large");
- int input_rows = static_cast<int>(input_rows_raw);
-
- // Input columns/width
- int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
- CHECK_BOUNDS(input_cols_raw, "Input cols too large");
- int input_cols = static_cast<int>(input_cols_raw);
-
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
+ if (strides_.size() == 4) { // NCHW format for Conv2D
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCHW format for Conv2D.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ } else if (strides_.size() == 5) { // NCDHW format for Conv3D
+ // Input planes/third-dimension
+ int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
+ CHECK_BOUNDS(input_planes_raw, "Input depth too large");
+ int input_planes = static_cast<int>(input_planes_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCDHW format for Conv3D.
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ }
#undef CHECK_BOUNDS
-
- // MKL-DNN always requires input in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
-
- *input_dims = mkldnn_sizes;
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
- //
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
@@ -159,11 +195,13 @@ class MklDnnConvUtil {
memory::dims* filter_dims) {
CHECK_NOTNULL(filter_dims);
- OP_REQUIRES(context_, filter_shape.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
+ OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
+ errors::InvalidArgument((strides_.size() == 4)
+ ? "filter must be 4-dimensional: "
+ : "filter must be 5-dimensional: ",
filter_shape.DebugString()));
- for (int i = 0; i < 3; i++) {
+ for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
OP_REQUIRES(context_,
FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
@@ -172,32 +210,57 @@ class MklDnnConvUtil {
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
-
- // TF filter is always in (rows, cols, in_depth, out_depth) order.
- int filter_rows = static_cast<int>(filter_shape.dim_size(0));
- int filter_cols = static_cast<int>(filter_shape.dim_size(1));
- int in_depth = static_cast<int>(filter_shape.dim_size(2));
- int out_depth = static_cast<int>(filter_shape.dim_size(3));
-
- // MKL-DNN always needs filter in OIHW format.
- // OIHW = (out_depth, in_depth, rows, cols)
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
-
- *filter_dims = mkldnn_sizes;
+ if (strides_.size() == 4) { // Conv2D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ } else { // Conv3D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(3)));
+
+ // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
+ int filter_planes = static_cast<int>(filter_shape.dim_size(0));
+ int filter_rows = static_cast<int>(filter_shape.dim_size(1));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(2));
+ int in_depth = static_cast<int>(filter_shape.dim_size(3));
+ int out_depth = static_cast<int>(filter_shape.dim_size(4));
+
+ // MKL-DNN always needs filter in OIDHW format.
+ // OIDHW = (out_depth, in_depth, planes, rows, cols)
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ }
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything. But errors arising from sanity
+ // checks are returned in context's status.
virtual inline void GetFilterSizeInMklOrder(size_t src_index,
size_t filter_index,
memory::dims* filter_dims) {
@@ -206,8 +269,8 @@ class MklDnnConvUtil {
GetTfShape(context_, filter_index), filter_dims);
}
- // Calculate Bias size for 2D Convolution. Function does not return
- // anything, but sets error in context status.
+ // Calculate Bias size for 2D or 3D Convolution. Function does not
+ // return anything, but may set an error in context status.
virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
memory::dims* bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
@@ -218,73 +281,142 @@ class MklDnnConvUtil {
*bias_dims = {static_cast<int>(bias.dim_size(0))};
}
- // Function to calculate output and padding size for 2D convolution.
+ // Function to calculate output and padding size for 2D/3D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function also calculates
- // left, right, top and bottom pads. Function does not return any status -
- // status is returned via context status.
+ // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order.
+ // But TensorFlow output will be in NHWC||NCHW(Conv2D) or
+ // NDHWC||NCDHW(Conv3D) format depending on data format.
+ // Function also calculates left, right, top and bottom pads.
+ // Function does not return a status; errors are set in the context status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order,
- memory::dims* output_dims_mkl_order, memory::dims* pad_l,
- memory::dims* pad_r) {
+ memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
+ memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- int input_rows = GetTensorDim(input_shape, data_format_, 'H');
- int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ bool isConv2D = (strides_.size() == 4);
+ int input_planes, input_rows, input_cols;
+ if (isConv2D) {
+ input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ } else {
+ input_planes = GetTensorDim(input_shape, data_format_, '0');
+ input_rows = GetTensorDim(input_shape, data_format_, '1');
+ input_cols = GetTensorDim(input_shape, data_format_, '2');
+ }
- // The first dimension for filter is rows/height.
- int filter_rows = filter_shape.dim_size(0);
- // The second dimension for filter is cols/width.
- int filter_cols = filter_shape.dim_size(1);
+ // Filter dimension
+ // Conv2D:
+ // First dimension: rows/height.
+ // Second dimension: cols/width.
+ // Conv3D:
+ // First dimension: planes/depth.
+ // Second dimension: rows/height.
+ // Third dimension: cols/width.
+
+ int filter_planes, filter_rows, filter_cols;
+ if (isConv2D) {
+ filter_rows = filter_shape.dim_size(0);
+ filter_cols = filter_shape.dim_size(1);
+ } else {
+ filter_planes = filter_shape.dim_size(0);
+ filter_rows = filter_shape.dim_size(1);
+ filter_cols = filter_shape.dim_size(2);
+ }
- // Stride is vector of 2 elements: {s_r, s_c}
- int stride_rows = strides[0];
- int stride_cols = strides[1];
- int dilation_rows = dilations[0];
- int dilation_cols = dilations[1];
+ int stride_planes, stride_rows, stride_cols;
+ int dilation_planes, dilation_rows, dilation_cols;
+ if (isConv2D) {
+ // Conv2D stride is a vector of 2 elements: {s_r, s_c}
+ stride_rows = strides[0];
+ stride_cols = strides[1];
+ dilation_rows = dilations[0];
+ dilation_cols = dilations[1];
+ } else {
+ // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
+ stride_planes = strides[0];
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ dilation_planes = dilations[0];
+ dilation_rows = dilations[1];
+ dilation_cols = dilations[2];
+ }
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+
// Output depth is same as last dimension for filter.
- int out_depth = filter_shape.dim_size(3);
+ int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4);
- int64 out_rows = 0, out_cols = 0;
+ int64 out_rows = 0, out_cols = 0, out_planes = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+ int64 pad_D1, pad_D2;
+
+ if (isConv2D) {
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_rows, filter_rows, dilation_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_cols, filter_cols, dilation_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ } else {
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_planes, filter_planes, stride_planes,
+ padding_, &out_planes, &pad_D1, &pad_D2));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ }
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
- dilation_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
- dilation_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
-
- // Tensorflow output is in data_format order. (NHWC or NCHW)
+ // Tensorflow output is in data_format order.
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
+ // MKL-DNN uses asymmetric padding.
TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ isConv2D
+ ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
+ out_depth)
+ : ShapeFromFormat(data_format_, out_batch,
+ {{out_planes, out_rows, out_cols}}, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
- // MKL-DNN always needs output in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
- mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
- *output_dims_mkl_order = mkldnn_sizes;
-
- // Now handle padding. MKL-DNN uses asymetric padding.
- *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
- *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ if (isConv2D) {
+ // For Conv2D, MKL-DNN always needs output in NCHW format.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ } else {
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
+ static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
+ static_cast<int>(pad_right)};
+ }
}
// Calculate output and pad size of forward Convolution operator.
@@ -292,10 +424,10 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index,
- const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
- memory::dims* pad_l, memory::dims* pad_r) {
+ size_t src_index, size_t filter_index, const memory::dims& strides,
+ const memory::dims& dilations, memory::dims* output_dims_tf_order,
+ memory::dims* output_dims_mkl_order, memory::dims* pad_l,
+ memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -304,9 +436,17 @@ class MklDnnConvUtil {
auto input_tf_shape = GetTfShape(context_, src_index);
auto filter_tf_shape = GetTfShape(context_, filter_index);
- OP_REQUIRES(context_, input_tf_shape.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ if (strides_.size() == 4) {
+ // Conv2D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_tf_shape.DebugString()));
+ } else {
+ // Conv3D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input_tf_shape.DebugString()));
+ }
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, dilations, output_dims_tf_order,
@@ -314,9 +454,11 @@ class MklDnnConvUtil {
}
// Wrapper function to calculate input, filter, and output sizes of
- // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
- // Function also calculates output shape in Tensorflow order. Additionally, it
- // also calculates strides and paddings for 2D Convolution.
+ // Conv2D/Conv3D in MKL order:
+ // Conv2D: NCHW for input and output; OIHW for filter.
+ // Conv3D: NCDHW for input and output; OIDHW for filter.
+ // Function also calculates output shape in Tensorflow order.
+ // Additionally, it also calculates strides and paddings.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(
@@ -349,16 +491,15 @@ class MklDnnConvUtil {
}
};
-
/////////////////////////////////////////////////////////////////////
-/// Common class that implements Conv2DBackpropFilter and Input
+/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
template <typename Device, class T>
-class MklConv2DBackpropCommonOp : public OpKernel {
+class MklConvBackpropCommonOp : public OpKernel {
public:
- ~MklConv2DBackpropCommonOp() {}
- explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
+ ~MklConvBackpropCommonOp() {}
+ explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
@@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ // Check Conv2D dilations
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e0f25fb4ef..385021b168 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1736,6 +1736,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklConv3D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter: uint8")
+ .Output("output: T")
+ .Output("filter_output: T")
+ .Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv3DShape)
+ .Doc(R"doc(
+MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropInputV2")
+ .Input("input_sizes: Tshape")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Input("mkl_input_sizes: uint8")
+ .Input("mkl_filter: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropFilterV2")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Input("out_backprop: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter_size: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")  // same minimum as the sibling Conv3D ops
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));  // 1: filter_sizes
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));  // filter must be rank 5
+ c->set_output(0, s);  // output 0 takes the shape given by filter_sizes
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the filter.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklRelu")
.Input("features: T")
.Input("mkl_features: uint8")
@@ -2161,7 +2242,7 @@ REGISTER_OP("_MklToTf")
.Input("mkl_input: uint8")
.Output("output: T")
.Attr("T: {half, float, double}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.
@@ -2183,7 +2264,7 @@ REGISTER_OP("_MklInputConversion")
.Attr(
"T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
"complex64, complex128}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 159a787d05..422be9356d 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -87,6 +87,16 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+typedef enum {
+ Dim3d_N = 0,
+ Dim3d_C = 1,
+ Dim3d_D = 2,
+ Dim3d_H = 3,
+ Dim3d_W = 4,
+ Dim3d_O = 0,
+ Dim3d_I = 1
+} MklDnnDims3D;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -351,6 +361,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -453,6 +464,13 @@ class MklDnnShape {
return this->DimSize(index);
}
+ inline size_t GetDimension3D(char dimension) const {
+ int index = GetMklDnnTensor3DDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index from the dimension: " << index << ", " << dimension;
+ return this->DimSize(index);
+ }
+
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -469,6 +487,24 @@ class MklDnnShape {
}
}
+ inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims3D::Dim3d_N;
+ case 'C':
+ return MklDnnDims3D::Dim3d_C;
+ case 'D':
+ return MklDnnDims3D::Dim3d_D;
+ case 'H':
+ return MklDnnDims3D::Dim3d_H;
+ case 'W':
+ return MklDnnDims3D::Dim3d_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
inline size_t GetDimension() const { return data_.dimension_; }
inline const int* GetSizes() const {
return reinterpret_cast<const int*>(&data_.sizes_[0]);
@@ -587,13 +623,26 @@ class MklDnnShape {
}
inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
- // TODO(nhasabni): Why do we restrict this to 4D?
- CHECK_EQ(dimension, 4);
- CHECK(dimension == data_.dimension_);
- data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
- data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
- data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
- data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ if (dimension == 5) {
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
+ MklDnnDims3D::Dim3d_D;
+ data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
+ MklDnnDims3D::Dim3d_H;
+ data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
+ MklDnnDims3D::Dim3d_W;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
+ MklDnnDims3D::Dim3d_C;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
+ MklDnnDims3D::Dim3d_N;
+ } else {
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
}
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
@@ -1329,6 +1378,19 @@ memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
}
+/// Map TensorFlow's data format into MKL-DNN 3D data format
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::ndhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::ncdhw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ return memory::format::format_undef;
+}
+
/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
@@ -1340,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1350,9 +1411,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
- if (format == memory::format::nhwc)
+ if (format == memory::format::nhwc || format == memory::format::ndhwc)
return FORMAT_NHWC;
- else if (format == memory::format::nchw)
+ else if (format == memory::format::nchw || format == memory::format::ncdhw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
@@ -1402,6 +1463,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+ int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+ int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+ int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+ // MKL-DNN requires dimensions in NCDHW format.
+ return memory::dims({n, c, d, h, w});
+}
+
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1591,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ /// Flag: true if the data is 3D; stays false until SetIs3DData() sets it.
+ bool bIs3D = false;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1540,6 +1619,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows: