diff options
author | 2018-08-16 13:04:52 -0700 | |
---|---|---|
committer | 2018-08-16 13:05:05 -0700 | |
commit | 9c50882415cb87a7eb81048d42401c64bf0617ef (patch) | |
tree | c550925b2d9e7f6997ace0e3bb3268572f7066b7 | |
parent | 19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff) | |
parent | 62191da0819b25906c1b2ed96159cfe36ba00383 (diff) |
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 39 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 174 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 144 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.cc | 157 | ||||
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 414 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 85 | ||||
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 103 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 1 |
9 files changed, 756 insertions, 365 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 5683944e46..833592caab 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2418,6 +2418,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.conv2d_grad_filter_with_bias = "__MklDummyConv2DBackpropFilterWithBias"; + csinfo_.conv3d = "Conv3D"; + csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2"; + csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; csinfo_.identity = "Identity"; @@ -2468,18 +2471,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsConcatV2, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, - csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, + csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_input, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d, + mkl_op_registry::GetMklOpName(csinfo_.conv3d), + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d_grad_filter, + mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter), + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d_grad_input, + 
mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input), + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite}); @@ -2614,6 +2626,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string conv2d_grad_input; string conv2d_grad_filter; string conv2d_grad_filter_with_bias; + string conv3d; + string conv3d_grad_input; + string conv3d_grad_filter; string fused_batch_norm; string fused_batch_norm_grad; string identity; @@ -3086,7 +3101,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); @@ -3571,14 +3586,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( // Op-specific functions to copy attributes from old node to new node ////////////////////////////////////////////////////////////////////////// -void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, - NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, + NodeBuilder* nb) { DataType T; string data_format; string padding; std::vector<int32> strides; std::vector<int32> dilations; - bool use_cudnn_on_gpu; // Get all attributes from old node. 
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); @@ -3586,8 +3600,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - TF_CHECK_OK( - GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); // Add attributes to new node. nb->Attr("T", T); @@ -3595,7 +3607,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, nb->Attr("dilations", dilations); nb->Attr("padding", padding); nb->Attr("data_format", data_format); - nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); } void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, @@ -3896,7 +3907,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd // Copy attributes from Conv2D to Conv2DWithBias. - CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); + CopyAttrsConv(const_cast<const Node*>(pred), &nb); // Copy the device assigned to old node to new node. nb.Device(succ->def().device()); @@ -4007,7 +4018,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( } // Copy attributes from Conv2DBackpropFilter. - CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb); + CopyAttrsConv(const_cast<const Node*>(fltr), &nb); // Copy the device assigned to old node to new node. 
nb.Device(fltr->def().device()); diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 50c25e1da7..afbfaa83f3 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -82,11 +82,11 @@ struct MklConvBwdFilterParams { }; template <typename T> -class MklConv2DBwdFilterPrimitive : public MklPrimitive { +class MklConvBwdFilterPrimitive : public MklPrimitive { public: - explicit MklConv2DBwdFilterPrimitive( - const MklConvBwdFilterParams& convBwdFilterDims) : - cpu_engine_(engine::cpu, 0) { + explicit MklConvBwdFilterPrimitive( + const MklConvBwdFilterParams& convBwdFilterDims) + : cpu_engine_(engine::cpu, 0) { context_.bwd_filter_stream.reset(new stream(stream::kind::eager)); // create conv primitive if (context_.conv_bwd_filter == nullptr) { @@ -94,7 +94,7 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive { } } - ~MklConv2DBwdFilterPrimitive() {} + ~MklConvBwdFilterPrimitive() {} // Convolution backward weights with bias // src_data: input data buffer of src @@ -297,38 +297,36 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive { }; template <typename T> -class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { +class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { public: - static MklConv2DBwdFilterPrimitive<T>* Get( + static MklConvBwdFilterPrimitive<T>* Get( const MklConvBwdFilterParams& convBwdFilterDims) { - MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr; + MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr; // look into the pool for reusable primitive - conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> ( - MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter( - convBwdFilterDims)); - - if (conv2d_bwd_filter == nullptr) { - conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>( - convBwdFilterDims); - 
MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter( - convBwdFilterDims, conv2d_bwd_filter); + conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>( + MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter( + convBwdFilterDims)); + + if (conv_bwd_filter == nullptr) { + conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims); + MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter( + convBwdFilterDims, conv_bwd_filter); } - return conv2d_bwd_filter; + return conv_bwd_filter; } - private: - MklConv2DBwdFilterPrimitiveFactory() {} - ~MklConv2DBwdFilterPrimitiveFactory() {} + MklConvBwdFilterPrimitiveFactory() {} + ~MklConvBwdFilterPrimitiveFactory() {} - static MklConv2DBwdFilterPrimitiveFactory& GetInstance() { - static MklConv2DBwdFilterPrimitiveFactory instance_; + static MklConvBwdFilterPrimitiveFactory& GetInstance() { + static MklConvBwdFilterPrimitiveFactory instance_; return instance_; } static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) { - string prefix = "conv2d_bwd_filter"; + string prefix = "conv_bwd_filter"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convBwdFilterDims.src_dims); @@ -342,14 +340,14 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { return key_creator.GetKey(); } - MklPrimitive* GetConv2dBwdFilter( + MklPrimitive* GetConvBwdFilter( const MklConvBwdFilterParams& convBwdFilterDims) { string key = CreateKey(convBwdFilterDims); return this->GetOp(key); } - void SetConv2dBwdFilter( - const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) { + void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims, + MklPrimitive* op) { string key = CreateKey(convBwdFilterDims); this->SetOp(key, op); } @@ -738,14 +736,13 @@ TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); #else template <typename Device, class T, bool biasEnabled> -class MklConv2DCustomBackpropFilterOp - : public 
MklConv2DBackpropCommonOp<Device, T> { +class MklConvCustomBackpropFilterOp + : public MklConvBackpropCommonOp<Device, T> { public: - explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp<Device, T>(context) { - } + explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context) + : MklConvBackpropCommonOp<Device, T>(context) {} - ~MklConv2DCustomBackpropFilterOp() {} + ~MklConvCustomBackpropFilterOp() {} void Compute(OpKernelContext* context) { try { @@ -753,6 +750,9 @@ class MklConv2DCustomBackpropFilterOp MklDnnData<T> diff_dst(&cpu_engine_); MklDnnData<T> diff_filter(&cpu_engine_); // output + // This flag indicates Conv2D or Conv3D + bool isConv2D = (this->strides_.size() == 4); + // Input tensors const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; const Tensor& src_tensor = MklGetInput(context, kInputIdx); @@ -813,7 +813,10 @@ class MklConv2DCustomBackpropFilterOp &fwd_dst_dims, &padding_left, &padding_right); if (!context->status().ok()) return; - auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_); + auto tf_fmt = isConv2D + ? TFDataFormatToMklDnnDataFormat(this->data_format_) + : TFDataFormatToMklDnn3DDataFormat(this->data_format_); + auto fwd_src_md = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetMklLayout() @@ -832,21 +835,19 @@ class MklConv2DCustomBackpropFilterOp if (biasEnabled) { TensorShape obp_tf_shape = GetTfShape(context, 2); depth = (this->data_format_ == FORMAT_NCHW) - ? obp_tf_shape.dim_size(1) - : obp_tf_shape.dim_size(3); + ? obp_tf_shape.dim_size(1) + : obp_tf_shape.dim_size(isConv2D ? 
3 : 4); diff_bias_dims = {static_cast<int>(depth)}; } + for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1; - dilations[kDilationH] -= 1; - dilations[kDilationW] -= 1; - - MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr; + MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr; MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides, dilations, padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); - conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get( - convBwdFilterDims); - auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc(); + conv_bwd_filter = + MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims); + auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc(); // allocate output tensors: diff_fitler and diff_bias (w bias) auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); @@ -854,14 +855,26 @@ class MklConv2DCustomBackpropFilterOp // diff_filter MklDnnShape diff_filter_mkl_shape; diff_filter_mkl_shape.SetMklTensor(false); - // output_dims_mkl_order is in OIHW format. - TensorShape diff_filter_tf_shape( - {bwd_output_dims[MklDnnDims::Dim_H], - bwd_output_dims[MklDnnDims::Dim_W], - bwd_output_dims[MklDnnDims::Dim_I], - bwd_output_dims[MklDnnDims::Dim_O]}); - AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, - diff_filter_tf_shape, diff_filter_mkl_shape); + + if (isConv2D) { + // Conv2D: output_dims_mkl_order is in OIHW format. + TensorShape diff_filter_tf_shape({bwd_output_dims[MklDnnDims::Dim_H], + bwd_output_dims[MklDnnDims::Dim_W], + bwd_output_dims[MklDnnDims::Dim_I], + bwd_output_dims[MklDnnDims::Dim_O]}); + AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, + diff_filter_tf_shape, diff_filter_mkl_shape); + } else { + // Conv3D: output_dims_mkl_order is in OIDHW format. 
+ TensorShape diff_filter_tf_shape( + {bwd_output_dims[MklDnnDims3D::Dim3d_D], + bwd_output_dims[MklDnnDims3D::Dim3d_H], + bwd_output_dims[MklDnnDims3D::Dim3d_W], + bwd_output_dims[MklDnnDims3D::Dim3d_I], + bwd_output_dims[MklDnnDims3D::Dim3d_O]}); + AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, + diff_filter_tf_shape, diff_filter_mkl_shape); + } Tensor* diff_bias_tensor = nullptr; if (biasEnabled) { @@ -871,7 +884,7 @@ class MklConv2DCustomBackpropFilterOp // check if src and diff_dst need reorder T *src_data = nullptr; - if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) { + if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) { src.SetUsrMem(fwd_src_md, &src_tensor); src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc()); src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); @@ -882,7 +895,7 @@ class MklConv2DCustomBackpropFilterOp T *diff_dst_data = nullptr; if (diff_dst_md.data.format != - conv2d_bwd_filter->GetDiffDstMemoryFormat()) { + conv_bwd_filter->GetDiffDstMemoryFormat()) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc()); diff_dst_data = static_cast<T*>( @@ -897,7 +910,7 @@ class MklConv2DCustomBackpropFilterOp bool diff_filter_reorder_required = false; T *diff_filter_data = nullptr; if (GetOutputFormat(tf_fmt) != - conv2d_bwd_filter->GetDiffFilterMemoryFormat()) { + conv_bwd_filter->GetDiffFilterMemoryFormat()) { // Allocate diff filter tensor as Tensorflow layout diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt), diff_filter_tensor); @@ -915,10 +928,10 @@ class MklConv2DCustomBackpropFilterOp if (biasEnabled) { T* diff_bias_data = static_cast<T*>(const_cast<T*>( diff_bias_tensor->flat<T>().data())); - conv2d_bwd_filter->Execute(src_data, diff_filter_data, - diff_bias_data, diff_dst_data); + conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data, + diff_dst_data); } else { - 
conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data); + conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data); } // Reorder diff_filter back to Tensorflow layout if necessary @@ -947,7 +960,7 @@ class MklConv2DCustomBackpropFilterOp const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { CHECK(!filter_mkl_shape.IsMklTensor()) - << "Conv2DBackpropFilter: filter should not be in MKL Layout"; + << "ConvBackpropFilter: filter should not be in MKL Layout"; } // Get TensorFlow shape of input tensor. @@ -983,9 +996,11 @@ class MklConv2DCustomBackpropFilterOp return fwd_filter_dims; } - // Output layout is Tensorflow's filter layout (HWIO). + // Output layout is Tensorflow's filter layout + // Conv2D: HWIO; Conv3D: DHWIO memory::format GetOutputFormat(const memory::format data_format) { - return memory::format::hwio; + return (this->strides_.size() == 4) ? memory::format::hwio + : memory::format::dhwio; } // Allocate output tensor. @@ -1027,24 +1042,27 @@ class MklConv2DCustomBackpropFilterOp } }; -#define REGISTER_MKL_FILTER_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklConv2DBackpropFilter") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklConv2DBackpropFilterWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \ - REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklDummyOp<CPUDevice, T>); +#define REGISTER_MKL_FILTER_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvCustomBackpropFilterOp<CPUDevice, T, false>); \ + 
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvCustomBackpropFilterOp<CPUDevice, T, true>); \ + REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklDummyOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvCustomBackpropFilterOp<CPUDevice, T, false>); TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); #undef REGISTER_MKL_FILTER_KERNELS diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 38e014d68e..b5a98301e2 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; #ifndef INTEL_MKL_ML_ONLY -/// utility classes enabling primitive reuse for backward conv2d ops. +/// utility classes enabling primitive reuse for backward conv ops. 
struct MklConvBwdInputParams { memory::dims diff_src_dims; memory::dims filter_dims; @@ -83,11 +83,11 @@ struct MklConvBwdInputParams { }; template <typename T> -class MklConv2DBwdInputPrimitive : public MklPrimitive { +class MklConvBwdInputPrimitive : public MklPrimitive { public: - explicit MklConv2DBwdInputPrimitive( - const MklConvBwdInputParams& convBwdInputDims) : - cpu_engine_(engine::cpu, 0) { + explicit MklConvBwdInputPrimitive( + const MklConvBwdInputParams& convBwdInputDims) + : cpu_engine_(engine::cpu, 0) { context_.bwd_input_stream.reset(new stream(stream::kind::eager)); // create conv primitive @@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive { Setup(convBwdInputDims); } } - ~MklConv2DBwdInputPrimitive() {} + ~MklConvBwdInputPrimitive() {} // Convolution backward filter (weights) // diff_src_data: output data buffer of diff_src @@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive { } private: - // Primitive reuse context for Conv2D Bwd Input op + // Primitive reuse context for Conv Bwd Input op struct ConvBwdInputContext { // expected memory format for this primitive instance memory::format filter_fmt; @@ -235,38 +235,37 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive { }; template <typename T> -class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { +class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { private: - MklConv2DBwdInputPrimitiveFactory() {} - ~MklConv2DBwdInputPrimitiveFactory() {} + MklConvBwdInputPrimitiveFactory() {} + ~MklConvBwdInputPrimitiveFactory() {} public: - static MklConv2DBwdInputPrimitive<T>* Get( + static MklConvBwdInputPrimitive<T>* Get( const MklConvBwdInputParams& convBwdInputDims) { - MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr; + MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr; // look into the pool for reusable primitive - conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> ( - 
MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput( + conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>( + MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput( convBwdInputDims)); - if (conv2d_bwd_input == nullptr) { - conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>( - convBwdInputDims); - MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput( - convBwdInputDims, conv2d_bwd_input); + if (conv_bwd_input == nullptr) { + conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims); + MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput( + convBwdInputDims, conv_bwd_input); } - return conv2d_bwd_input; + return conv_bwd_input; } private: - static MklConv2DBwdInputPrimitiveFactory& GetInstance() { - static MklConv2DBwdInputPrimitiveFactory instance_; + static MklConvBwdInputPrimitiveFactory& GetInstance() { + static MklConvBwdInputPrimitiveFactory instance_; return instance_; } static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) { - string prefix = "conv2d_bwd_input"; + string prefix = "conv_bwd_input"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convBwdInputDims.diff_src_dims); @@ -279,14 +278,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { return key_creator.GetKey(); } - MklPrimitive* GetConv2dBwdInput( - const MklConvBwdInputParams& convBwdInputDims) { + MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) { string key = CreateKey(convBwdInputDims); return this->GetOp(key); } - void SetConv2dBwdInput( - const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) { + void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims, + MklPrimitive* op) { string key = CreateKey(convBwdInputDims); this->SetOp(key, op); } @@ -594,23 +592,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { TensorFormat data_format; }; +#define REGISTER_MKL_CPU_KERNELS(T) \ + 
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConv2DCustomBackpropInputOp<CPUDevice, T>); + +TF_CALL_float(REGISTER_MKL_CPU_KERNELS); +#undef REGISTER_MKL_CPU_KERNELS + #else template <typename Device, class T> -class MklConv2DCustomBackpropInputOp - : public MklConv2DBackpropCommonOp<Device, T> { +class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> { public: - explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp<Device, T>(context) { - } + explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context) + : MklConvBackpropCommonOp<Device, T>(context) {} - ~MklConv2DCustomBackpropInputOp() {} + ~MklConvCustomBackpropInputOp() {} void Compute(OpKernelContext* context) { try { MklDnnData<T> filter(&cpu_engine); MklDnnData<T> diff_dst(&cpu_engine); + // This flag indicate Conv2D or Conv3D + bool isConv2D = (this->strides_.size() == 4); + // Input tensors const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; const Tensor& src_tensor = MklGetInput(context, kInputIdx); @@ -626,7 +635,7 @@ class MklConv2DCustomBackpropInputOp diff_dst_mkl_shape); // Allow operator-specific generation of shapes. - // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a + // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a // tensor containing shape of filter. So filter.shape() is not // a correct way to get filter shape. These operator-specific calls // allow this class to handle this case. @@ -655,6 +664,7 @@ class MklConv2DCustomBackpropInputOp } return; } + // By default, all dims are in MKL order. Only dims in TF order // are those with postfix tf_order. memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims; @@ -673,15 +683,18 @@ class MklConv2DCustomBackpropInputOp // Create Convolution forward descriptor since Convolution backward // API needs it. 
For that, we first need to create input, filter // and output memory descriptors. - auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_); + auto tf_fmt = isConv2D + ? TFDataFormatToMklDnnDataFormat(this->data_format_) + : TFDataFormatToMklDnn3DDataFormat(this->data_format_); // If filter is in MKL layout, then simply grab filter layout; // otherwise, construct filter in TF layout. // For TF layout, filter is in HWIO format. auto fwd_filter_md = filter_mkl_shape.IsMklTensor() - ? filter_mkl_shape.GetMklLayout() - : memory::desc(fwd_filter_dims, MklDnnType<T>(), - memory::format::hwio); + ? filter_mkl_shape.GetMklLayout() + : memory::desc(fwd_filter_dims, MklDnnType<T>(), + isConv2D ? memory::format::hwio + : memory::format::dhwio); conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); if (!context->status().ok()) return; @@ -689,18 +702,15 @@ class MklConv2DCustomBackpropInputOp ? diff_dst_mkl_shape.GetMklLayout() : memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt); + for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1; - dilations[kDilationH] -= 1; - dilations[kDilationW] -= 1; - - MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr; - conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr; MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations, padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); - conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get( - convBwdInputDims); - auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc(); + conv_bwd_input = + MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims); + auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc(); // allocate output tensor auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc(); @@ -723,7 +733,7 @@ class MklConv2DCustomBackpropInputOp // check if filter and diff_dst need reorder T* filter_data = nullptr; if 
(fwd_filter_md.data.format != - conv2d_bwd_input->GetFilterMemoryFormat()) { + conv_bwd_input->GetFilterMemoryFormat()) { filter.SetUsrMem(fwd_filter_md, &filter_tensor); filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc()); filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle()); @@ -733,8 +743,7 @@ class MklConv2DCustomBackpropInputOp } T* diff_dst_data = nullptr; - if (diff_dst_md.data.format != - conv2d_bwd_input->GetDiffDstMemoryFormat()) { + if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc()); diff_dst_data = static_cast<T*>( @@ -745,7 +754,7 @@ class MklConv2DCustomBackpropInputOp } // execute convolution input bwd - conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); + conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -770,7 +779,7 @@ class MklConv2DCustomBackpropInputOp // of the Tensor and never an actual tensor. So it will never be in MKL // layout. CHECK(!input_mkl_shape.IsMklTensor()) - << "Conv2DBackpropInput: input should not be in MKL Layout"; + << "ConvBackpropInput: input should not be in MKL Layout"; } // Get TensorFlow shape of input tensor. @@ -778,10 +787,10 @@ class MklConv2DCustomBackpropInputOp const Tensor& input_tensor) { TensorShape input_tf_shape; CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true); - CHECK_EQ( - TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape) - .ok(), - true); + // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64 + // output_shape MakeShape is able to handle both DT_INT32 and DT_INT64 for + // input_tensor. 
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true); return input_tf_shape; } @@ -792,7 +801,7 @@ class MklConv2DCustomBackpropInputOp } // Get the Tensorflow shape of Output (diff_src), - // which is same as shape of Conv2D 'input'. + // which is same as shape of Conv 'input'. TensorShape GetOutputTfShape(const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& outbprop_shape) { @@ -800,7 +809,7 @@ class MklConv2DCustomBackpropInputOp } // Get the Tensorflow shape of Output (diff_src), - // which is same as shape of Conv2D 'input'. + // which is same as shape of Conv 'input'. const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, const memory::dims& fwd_filter_dims) { return fwd_input_dims; @@ -839,17 +848,22 @@ class MklConv2DCustomBackpropInputOp } }; -#endif // INTEL_MKL_ML_ONLY - -#define REGISTER_MKL_CPU_KERNELS(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DCustomBackpropInputOp<CPUDevice, T>); +#define REGISTER_MKL_CPU_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvCustomBackpropInputOp<CPUDevice, T>); TF_CALL_float(REGISTER_MKL_CPU_KERNELS); #undef REGISTER_MKL_CPU_KERNELS +#endif // INTEL_MKL_ML_ONLY + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index bca1aa21a8..c6295c7280 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -85,9 +85,9 @@ struct MklConvFwdParams { }; template <typename T> -class 
MklConv2DFwdPrimitive : public MklPrimitive { +class MklConvFwdPrimitive : public MklPrimitive { public: - explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) + explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) : cpu_engine_(engine::cpu, 0) { context_.fwd_stream.reset(new stream(stream::kind::eager)); // create conv primitive @@ -96,7 +96,7 @@ class MklConv2DFwdPrimitive : public MklPrimitive { } } - ~MklConv2DFwdPrimitive() {} + ~MklConvFwdPrimitive() {} // Convolution forward execute with bias // src_data: input data buffer of src @@ -269,37 +269,36 @@ class MklConv2DFwdPrimitive : public MklPrimitive { }; template <typename T> -class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { +class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> { public: - static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) { - MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; + static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) { + MklConvFwdPrimitive<T>* conv_fwd = nullptr; // try to find a suitable one in pool - conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>( - MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd( - convFwdDims)); - - if (conv2d_fwd == nullptr) { - conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims); - MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims, - conv2d_fwd); + conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>( + MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims)); + + if (conv_fwd == nullptr) { + conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims); + MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims, + conv_fwd); } - return conv2d_fwd; + return conv_fwd; } private: - MklConv2DFwdPrimitiveFactory() {} - ~MklConv2DFwdPrimitiveFactory() {} + MklConvFwdPrimitiveFactory() {} + ~MklConvFwdPrimitiveFactory() {} static const int kDilationH = 0, kDilationW = 1; - static MklConv2DFwdPrimitiveFactory& 
GetInstance() { - static MklConv2DFwdPrimitiveFactory instance_; + static MklConvFwdPrimitiveFactory& GetInstance() { + static MklConvFwdPrimitiveFactory instance_; return instance_; } static string CreateKey(const MklConvFwdParams& convFwdDims) { - string prefix = "conv2d_fwd_"; + string prefix = "conv_fwd_"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convFwdDims.src_dims); @@ -313,12 +312,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { return key_creator.GetKey(); } - MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) { + MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) { string key = CreateKey(convFwdDims); return this->GetOp(key); } - void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) { + void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) { string key = CreateKey(convFwdDims); this->SetOp(key, op); } @@ -331,11 +330,11 @@ typedef Eigen::ThreadPoolDevice CPUDevice; // For now, MKL-ML is default. So making MKL-DNN not a default choice. 
#ifdef INTEL_MKL_ML_ONLY template <typename Device, typename T, bool biasEnabled> -class MklConv2DOp : public OpKernel { +class MklConvOp : public OpKernel { public: - ~MklConv2DOp() {} + ~MklConvOp() {} - explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) { + explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); @@ -755,21 +754,22 @@ class MklConv2DOp : public OpKernel { #else +// Base class for convolution forward operations template <typename Device, typename T, bool biasEnabled> -class MklConv2DOp : public OpKernel { +class MklConvOp : public OpKernel { public: - ~MklConv2DOp() {} + ~MklConvOp() {} - explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) { + explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(context, strides_.size() == 4, + OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); + "specify 4 or 5 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); @@ -778,20 +778,39 @@ class MklConv2DOp : public OpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding 
window dilations field must " - "specify 4 dimensions")); - const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + + if (strides_.size() == 4) { + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + } else if (strides_.size() == 5) { + OP_REQUIRES(context, dilations_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilations_, data_format_, 'N') == 1 && + GetTensorDim(dilations_, data_format_, 'C') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations rates in the batch and depth dimensions.")); + OP_REQUIRES( + context, + (GetTensorDim(dilations_, data_format_, '0') > 0 && + GetTensorDim(dilations_, data_format_, '1') > 
0 && + GetTensorDim(dilations_, data_format_, '2') > 0), + errors::InvalidArgument("Dilated rates should be larger than 0.")); + } } void Compute(OpKernelContext* context) override { @@ -837,7 +856,8 @@ class MklConv2DOp : public OpKernel { AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor, src_tf_shape, dst_mkl_shape); - // MklConv2D also outputs converted filter as 2nd output of Conv2D. + // MklConv2D/3D also outputs converted filter + // as 2nd output of Conv2D/3D. filter_mkl_shape.SetMklTensor(false); Tensor* output_filter_tensor = nullptr; AllocateOutputSetMklShape(context, kOutputIndex_Filter, @@ -846,15 +866,20 @@ class MklConv2DOp : public OpKernel { return; } + bool isConv2D = (strides_.size() == 4); + // Create memory for user data. // Describe how the inputs and outputs of Convolution look like. Also // specify buffers containing actual input and output data. - auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_); + auto tf_fmt = isConv2D ? TFDataFormatToMklDnnDataFormat(data_format_) + : TFDataFormatToMklDnn3DDataFormat(data_format_); // If input is in MKL layout, then simply grab input layout; otherwise, // construct input Tf layout. For TF layout, although input shape // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's - // layout (NHWC or NCHW depending on data format). + // layout depending on data format: + // Conv2D: NHWC or NCHW + // Conv3D: NDHWC or NCDHW auto src_md = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetMklLayout() : memory::desc(src_dims, MklDnnType<T>(), tf_fmt); @@ -864,31 +889,30 @@ class MklConv2DOp : public OpKernel { auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true ? filter_mkl_shape.GetMklLayout() : memory::desc(filter_dims, MklDnnType<T>(), - memory::format::hwio); - + isConv2D ? memory::format::hwio + : memory::format::dhwio); // MKLDNN dilation starts from 0. 
- dilations[kDilationH] -= 1; - dilations[kDilationW] -= 1; + for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1; // get a conv2d fwd from primitive pool - MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr; + MklConvFwdPrimitive<T>* conv_fwd = nullptr; if (biasEnabled) { memory::dims bias_dims = {}; conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims, dst_dims_mkl_order, strides, dilations, padding_left, padding_right); - conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); + conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims); } else { MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS, dst_dims_mkl_order, strides, dilations, padding_left, padding_right); - conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims); + conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims); } // allocate output tensors output_tensor and filter_out_tensor std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd = - conv2d_fwd->GetPrimitiveDesc(); + conv_fwd->GetPrimitiveDesc(); AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt, &dst_tensor); Tensor* filter_out_tensor = nullptr; @@ -900,7 +924,7 @@ class MklConv2DOp : public OpKernel { // check whether src/filter need reorder T *src_data = nullptr; - if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) { + if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) { src.SetUsrMem(src_md, &src_tensor); src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc()); src_data = static_cast<T*>(src.GetOpMem().get_data_handle()); @@ -908,7 +932,7 @@ class MklConv2DOp : public OpKernel { src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data())); } T* filter_data = nullptr; - if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) { + if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) { filter.SetUsrMem(filter_md, &filter_tensor); 
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(), filter.GetTensorBuffer(filter_out_tensor)); @@ -918,16 +942,15 @@ class MklConv2DOp : public OpKernel { static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data())); } - // execute convolution if (biasEnabled) { const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias); T* bias_data = static_cast<T*>(const_cast<T*>( bias_tensor.flat<T>().data())); - conv2d_fwd->Execute(src_data, filter_data, bias_data, dst_data); + conv_fwd->Execute(src_data, filter_data, bias_data, dst_data); } else { - conv2d_fwd->Execute(src_data, filter_data, dst_data); + conv_fwd->Execute(src_data, filter_data, dst_data); } } catch (mkldnn::error &e) { string error_msg = tensorflow::strings::StrCat( @@ -1038,17 +1061,18 @@ class MklConv2DOp : public OpKernel { #endif +// Register 2D operations #define REGISTER_MKL_CPU(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DOp<CPUDevice, T, false>); \ + MklConvOp<CPUDevice, T, false>); \ REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklConv2DOp<CPUDevice, T, true>); \ + MklConvOp<CPUDevice, T, true>); \ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ @@ -1057,5 +1081,14 @@ class MklConv2DOp : public OpKernel { TF_CALL_float(REGISTER_MKL_CPU); +// Register 3D operations +#define REGISTER_MKL_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklConvOp<CPUDevice, T, false>); +TF_CALL_float(REGISTER_MKL_CPU); + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 838c06f49d..01cc606f41 100644 --- 
a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -79,9 +79,16 @@ class MklDnnConvUtil { // For now we take the stride from the second and third dimensions only // (we do not support striding on the batch or depth dimension). CHECK_NOTNULL(strides); - int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - *strides = {stride_rows, stride_cols}; + if (strides_.size() == 4) { + int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + *strides = {stride_rows, stride_cols}; + } else if (strides_.size() == 5) { + int stride_planes = GetTensorDim(strides_, data_format_, '0'); + int stride_rows = GetTensorDim(strides_, data_format_, '1'); + int stride_cols = GetTensorDim(strides_, data_format_, '2'); + *strides = {stride_planes, stride_rows, stride_cols}; + } } // Calculate Convolution dilations @@ -89,13 +96,20 @@ class MklDnnConvUtil { // For now we take the dilation from the second and third dimensions only // (we do not support dilation on the batch or depth dimension). CHECK_NOTNULL(dilations); - int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); - int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); - *dilations = {dilations_rows, dilations_cols}; + if (dilations_.size() == 4) { + int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); + int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); + *dilations = {dilations_rows, dilations_cols}; + } else if (dilations_.size() == 5) { + int dilations_planes = GetTensorDim(dilations_, data_format_, '0'); + int dilations_rows = GetTensorDim(dilations_, data_format_, '1'); + int dilations_cols = GetTensorDim(dilations_, data_format_, '2'); + *dilations = {dilations_planes, dilations_rows, dilations_cols}; + } } // Calculate Convolution input size in MKL-DNN order. 
MKL-DNN - // requires input in NCHW format. Function does not return anything. + // requires input in NCHW/NCDHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape, @@ -113,40 +127,62 @@ class MklDnnConvUtil { int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C'); int input_depth = static_cast<int>(input_depth_raw); - // Input rows/height - int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); - CHECK_BOUNDS(input_rows_raw, "Input rows too large"); - int input_rows = static_cast<int>(input_rows_raw); - - // Input columns/width - int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); - CHECK_BOUNDS(input_cols_raw, "Input cols too large"); - int input_cols = static_cast<int>(input_cols_raw); - // Input batch int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N'); CHECK_BOUNDS(input_batch_raw, "Input batch too large"); int input_batch = static_cast<int>(input_batch_raw); + if (strides_.size() == 4) { // NCHW format for Conv2D + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast<int>(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast<int>(input_cols_raw); + + // MKL-DNN always requires input in NCHW format Conv2D. 
+ std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; + + *input_dims = mkldnn_sizes; + } else if (strides_.size() == 5) { // NCDHW format for Conv3D + // Input planes/third-dimension + int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0'); + CHECK_BOUNDS(input_planes_raw, "Input depth too large"); + int input_planes = static_cast<int>(input_planes_raw); + + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast<int>(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast<int>(input_cols_raw); + + // MKL-DNN always requires input in NCDHW format for Conv3D. + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch; + mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes; + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows; + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols; + + *input_dims = mkldnn_sizes; + } #undef CHECK_BOUNDS - - // MKL-DNN always requires input in NCHW format. - std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; - mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; - mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; - - *input_dims = mkldnn_sizes; } - // Calculate Convolution filter size in MKL-DNN order. MKL-DNN - // requires filter in OIHW format. Function does not return anything. - // But errors arising from sanity checks are returned in context's - // status. - // - // Calculate Convolution filter size in MKL-DNN order. 
MKL-DNN - // requires filter in OIHW format. Function does not return anything. + // Calculate Convolution filter size in MKL-DNN order. + // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format. + // Function does not return anything. // But errors arising from sanity checks are returned in context's // status. This function differs from GetConvFilterSizeInMklOrder in // parameter for input - it accepts src_shape since Convolution Backward @@ -159,11 +195,13 @@ class MklDnnConvUtil { memory::dims* filter_dims) { CHECK_NOTNULL(filter_dims); - OP_REQUIRES(context_, filter_shape.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", + OP_REQUIRES(context_, filter_shape.dims() == strides_.size(), + errors::InvalidArgument((strides_.size() == 4) + ? "filter must be 4-dimensional: " + : "filter must be 5-dimensional: ", filter_shape.DebugString())); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) { OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), std::numeric_limits<int>::max()), @@ -172,32 +210,57 @@ class MklDnnConvUtil { int input_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter_shape.dim_size(2))); - - // TF filter is always in (rows, cols, in_depth, out_depth) order. - int filter_rows = static_cast<int>(filter_shape.dim_size(0)); - int filter_cols = static_cast<int>(filter_shape.dim_size(1)); - int in_depth = static_cast<int>(filter_shape.dim_size(2)); - int out_depth = static_cast<int>(filter_shape.dim_size(3)); - - // MKL-DNN always needs filter in OIHW format. 
- // OIHW = (out_depth, in_depth, rows, cols) - std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_O] = out_depth; - mkldnn_sizes[MklDnnDims::Dim_I] = in_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; - mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; - - *filter_dims = mkldnn_sizes; + if (strides_.size() == 4) { // Conv2D + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(2))); + + // TF filter is always in (rows, cols, in_depth, out_depth) order. + int filter_rows = static_cast<int>(filter_shape.dim_size(0)); + int filter_cols = static_cast<int>(filter_shape.dim_size(1)); + int in_depth = static_cast<int>(filter_shape.dim_size(2)); + int out_depth = static_cast<int>(filter_shape.dim_size(3)); + + // MKL-DNN always needs filter in OIHW format. + // OIHW = (out_depth, in_depth, rows, cols) + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_O] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_I] = in_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; + + *filter_dims = mkldnn_sizes; + } else { // Conv3D + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3), + errors::InvalidArgument( + "input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(3))); + + // TF filter is always in (planes, rows, cols, in_depth, out_depth) order. + int filter_planes = static_cast<int>(filter_shape.dim_size(0)); + int filter_rows = static_cast<int>(filter_shape.dim_size(1)); + int filter_cols = static_cast<int>(filter_shape.dim_size(2)); + int in_depth = static_cast<int>(filter_shape.dim_size(3)); + int out_depth = static_cast<int>(filter_shape.dim_size(4)); + + // MKL-DNN always needs filter in OIDHW format. 
+ // OIDHW = (out_depth, in_depth, planes, rows, cols) + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows; + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols; + + *filter_dims = mkldnn_sizes; + } } - // Calculate Convolution filter size in MKL-DNN order. MKL-DNN - // requires filter in OIHW format. Function does not return anything. - // But errors arising from sanity checks are returned in context's - // status. + // Calculate Convolution filter size in MKL-DNN order. + // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format. + // Function does not return anything. But errors arising from sanity + // checks are returned in context's status. virtual inline void GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, memory::dims* filter_dims) { @@ -206,8 +269,8 @@ class MklDnnConvUtil { GetTfShape(context_, filter_index), filter_dims); } - // Calculate Bias size for 2D Convolution. Function does not return - // anything, but sets error in context status. + // Calculate Bias size for 2D or 3D Convolution. Function does not + // return anything, but may set an error in context status. virtual inline void GetBiasSizeInMklOrder(size_t bias_index, memory::dims* bias_dims) { const Tensor& bias = MklGetInput(context_, bias_index); @@ -218,73 +281,142 @@ class MklDnnConvUtil { *bias_dims = {static_cast<int>(bias.dim_size(0))}; } - // Function to calculate output and padding size for 2D convolution. + // Function to calculate output and padding size for 2D/3D convolution. // // Calculate output shape of Convolution in MKL-DNN and TensorFlow order. - // MKL-DNN uses NCHW for output order. But TensorFlow output will be in - // NHWC or NCHW format depending on data format. Function also calculates - // left, right, top and bottom pads. 
Function does not return any status - - // status is returned via context status. + // MKL-DNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order. + // But TensorFlow output will be in NHWC||NCHW(Conv2D) or + // NDHWC||NCDHW(Conv3D) format depending on data format. + // Function also calculates left, right, top and bottom pads. + // Function does not return any status which is set with context status. // // TODO(nhasabni): Add similar function for input and filter in MklShape. virtual inline void GetOutputAndPadSizeInMklOrder( const TensorShape& input_shape, const TensorShape& filter_shape, const memory::dims& strides, const memory::dims& dilations, - memory::dims* output_dims_tf_order, - memory::dims* output_dims_mkl_order, memory::dims* pad_l, - memory::dims* pad_r) { + memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, + memory::dims* pad_l, memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); CHECK_NOTNULL(pad_r); - int input_rows = GetTensorDim(input_shape, data_format_, 'H'); - int input_cols = GetTensorDim(input_shape, data_format_, 'W'); + bool isConv2D = (strides_.size() == 4); + int input_planes, input_rows, input_cols; + if (isConv2D) { + input_rows = GetTensorDim(input_shape, data_format_, 'H'); + input_cols = GetTensorDim(input_shape, data_format_, 'W'); + } else { + input_planes = GetTensorDim(input_shape, data_format_, '0'); + input_rows = GetTensorDim(input_shape, data_format_, '1'); + input_cols = GetTensorDim(input_shape, data_format_, '2'); + } - // The first dimension for filter is rows/height. - int filter_rows = filter_shape.dim_size(0); - // The second dimension for filter is cols/width. - int filter_cols = filter_shape.dim_size(1); + // Filter dimension + // Conv2D: + // First dimension: rows/height. + // Second dimension: cols/width. + // Conv3D: + // First dimension: planes/depth. + // Second dimension: rows/height. + // Third dimension: cols/width. 
+ + int filter_planes, filter_rows, filter_cols; + if (isConv2D) { + filter_rows = filter_shape.dim_size(0); + filter_cols = filter_shape.dim_size(1); + } else { + filter_planes = filter_shape.dim_size(0); + filter_rows = filter_shape.dim_size(1); + filter_cols = filter_shape.dim_size(2); + } - // Stride is vector of 2 elements: {s_r, s_c} - int stride_rows = strides[0]; - int stride_cols = strides[1]; - int dilation_rows = dilations[0]; - int dilation_cols = dilations[1]; + int stride_planes, stride_rows, stride_cols; + int dilation_planes, dilation_rows, dilation_cols; + if (isConv2D) { + // Conv2D stride is a vector of 2 elements: {s_r, s_c} + stride_rows = strides[0]; + stride_cols = strides[1]; + dilation_rows = dilations[0]; + dilation_cols = dilations[1]; + } else { + // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c} + stride_planes = strides[0]; + stride_rows = strides[1]; + stride_cols = strides[2]; + dilation_planes = dilations[0]; + dilation_rows = dilations[1]; + dilation_cols = dilations[2]; + } // Output batch is same as input batch. int out_batch = GetTensorDim(input_shape, data_format_, 'N'); + // Output depth is same as last dimension for filter. - int out_depth = filter_shape.dim_size(3); + int out_depth = filter_shape.dim_size(isConv2D ? 
3 : 4); - int64 out_rows = 0, out_cols = 0; + int64 out_rows = 0, out_cols = 0, out_planes = 0; int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; + int64 pad_D1, pad_D2; + + if (isConv2D) { + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerboseV2( + input_rows, filter_rows, dilation_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerboseV2( + input_cols, filter_cols, dilation_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); + } else { + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_planes, filter_planes, stride_planes, + padding_, &out_planes, &pad_D1, &pad_D2)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); + } - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2(input_rows, filter_rows, - dilation_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context_, - GetWindowedOutputSizeVerboseV2(input_cols, filter_cols, - dilation_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); - - // Tensorflow output is in data_format order. (NHWC or NCHW) + // Tensorflow output is in data_format order. + // Conv2D: NHWC or NCHW + // Conv3D: NDHWC or NCDHW + // MKL-DNN uses asymetric padding. TensorShape out_shape = - ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth); + isConv2D + ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, + out_depth) + : ShapeFromFormat(data_format_, out_batch, + {{out_planes, out_rows, out_cols}}, out_depth); *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); - // MKL-DNN always needs output in NCHW format. 
- std::vector<int> mkldnn_sizes(4, -1); - mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; - mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; - mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); - mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); - *output_dims_mkl_order = mkldnn_sizes; - - // Now handle padding. MKL-DNN uses asymetric padding. - *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; - *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)}; + if (isConv2D) { + // For Conv2D, MKL-DNN always needs output in NCHW format. + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); + mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); + *output_dims_mkl_order = mkldnn_sizes; + + *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; + *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)}; + } else { + std::vector<int> mkldnn_sizes(5, -1); + mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch; + mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth; + mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes); + mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows); + mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols); + *output_dims_mkl_order = mkldnn_sizes; + + *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top), + static_cast<int>(pad_left)}; + *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom), + static_cast<int>(pad_right)}; + } } // Calculate output and pad size of forward Convolution operator. @@ -292,10 +424,10 @@ class MklDnnConvUtil { // // Function does not return anything, but sets error in context status. 
inline void GetOutputAndPadSizeInMklOrder( - size_t src_index, size_t filter_index, - const memory::dims& strides, const memory::dims& dilations, - memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, - memory::dims* pad_l, memory::dims* pad_r) { + size_t src_index, size_t filter_index, const memory::dims& strides, + const memory::dims& dilations, memory::dims* output_dims_tf_order, + memory::dims* output_dims_mkl_order, memory::dims* pad_l, + memory::dims* pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -304,9 +436,17 @@ class MklDnnConvUtil { auto input_tf_shape = GetTfShape(context_, src_index); auto filter_tf_shape = GetTfShape(context_, filter_index); - OP_REQUIRES(context_, input_tf_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_tf_shape.DebugString())); + if (strides_.size() == 4) { + // Conv2D + OP_REQUIRES(context_, input_tf_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_tf_shape.DebugString())); + } else { + // Conv3D + OP_REQUIRES(context_, input_tf_shape.dims() == 5, + errors::InvalidArgument("input must be 5-dimensional", + input_tf_shape.DebugString())); + } GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, dilations, output_dims_tf_order, @@ -314,9 +454,11 @@ class MklDnnConvUtil { } // Wrapper function to calculate input, filter, and output sizes of - // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.) - // Function also calculates output shape in Tensorflow order. Additionally, it - // also calculates strides and paddings for 2D Convolution. + // Conv2D/Conv3D in MKL order: + // Conv2D: NCHW for input and output; OIHW for filter. + // Conv3D: NCDHW for input and output; OIDHW for filter. + // Function also calculates output shape in Tensorflow order. + // Additionally, it also calculates strides and paddings. 
// // Function does not return anything, but sets error in context status. inline void GetConvFwdSizesInMklOrder( @@ -349,16 +491,15 @@ class MklDnnConvUtil { } }; - ///////////////////////////////////////////////////////////////////// -/// Common class that implements Conv2DBackpropFilter and Input +/// Common class that implements ConvBackpropFilter and Input ///////////////////////////////////////////////////////////////////// template <typename Device, class T> -class MklConv2DBackpropCommonOp : public OpKernel { +class MklConvBackpropCommonOp : public OpKernel { public: - ~MklConv2DBackpropCommonOp() {} - explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context) + ~MklConvBackpropCommonOp() {} + explicit MklConvBackpropCommonOp(OpKernelConstruction* context) : OpKernel(context) { string data_format_str; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); @@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); - int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + + if (strides_.size() == 4) { + // Check Conv2D dilations + OP_REQUIRES(context, dilations_.size() == 4, + 
errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e0f25fb4ef..385021b168 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1736,6 +1736,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +REGISTER_OP("_MklConv3D") + .Input("input: T") + .Input("filter: T") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Output("output: T") + .Output("filter_output: T") + .Output("mkl_output: uint8") + .Output("mkl_filter_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .SetShapeFn(shape_inference::Conv3DShape) + .Doc(R"doc( +MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklConv3DBackpropInputV2") + .Input("input_sizes: Tshape") + .Input("filter: T") + .Input("out_backprop: T") + .Input("mkl_input_sizes: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int) >= 5") + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .Attr("Tshape: {int32, int64} = DT_INT32") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklConv3DBackpropFilterV2") + .Input("input: T") + .Input("filter_sizes: int32") + .Input("out_backprop: T") + .Input("mkl_input: uint8") + .Input("mkl_filter_size: uint8") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the filter. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + REGISTER_OP("_MklRelu") .Input("features: T") .Input("mkl_features: uint8") @@ -2161,7 +2242,7 @@ REGISTER_OP("_MklToTf") .Input("mkl_input: uint8") .Output("output: T") .Attr("T: {half, float, double}") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to convert a tensor from MKL layout to TensorFlow layout. @@ -2183,7 +2264,7 @@ REGISTER_OP("_MklInputConversion") .Attr( "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " "complex64, complex128}") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to process the inputs to an elementwise MKL op. Both inputs diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 159a787d05..422be9356d 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -87,6 +87,16 @@ typedef enum { Dim_I = 1 } MklDnnDims; +typedef enum { + Dim3d_N = 0, + Dim3d_C = 1, + Dim3d_D = 2, + Dim3d_H = 3, + Dim3d_W = 4, + Dim3d_O = 0, + Dim3d_I = 1 +} MklDnnDims3D; + #ifdef INTEL_MKL_ML_ONLY class MklShape { public: @@ -351,6 +361,7 @@ class MklShape { #else // Forward decl +TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format); TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format); memory::dims CalculateTFStrides(const memory::dims& dims_tf_order); memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, @@ -453,6 +464,13 @@ class MklDnnShape { return this->DimSize(index); } + inline size_t GetDimension3D(char dimension) const { + int index = GetMklDnnTensor3DDimIndex(dimension); + CHECK(index >= 0 && index < this->GetDimension()) + << "Invalid index from the dimension: " << index << ", " << dimension; + return this->DimSize(index); + } + inline int32 GetMklDnnTensorDimIndex(char dimension) const { switch (dimension) { 
case 'N': @@ -469,6 +487,24 @@ class MklDnnShape { } } + inline int32 GetMklDnnTensor3DDimIndex(char dimension) const { + switch (dimension) { + case 'N': + return MklDnnDims3D::Dim3d_N; + case 'C': + return MklDnnDims3D::Dim3d_C; + case 'D': + return MklDnnDims3D::Dim3d_D; + case 'H': + return MklDnnDims3D::Dim3d_H; + case 'W': + return MklDnnDims3D::Dim3d_W; + default: + LOG(FATAL) << "Invalid dimension: " << dimension; + return -1; // Avoid compiler warning about missing return value + } + } + inline size_t GetDimension() const { return data_.dimension_; } inline const int* GetSizes() const { return reinterpret_cast<const int*>(&data_.sizes_[0]); @@ -587,13 +623,26 @@ class MklDnnShape { } inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) { - // TODO(nhasabni): Why do we restrict this to 4D? - CHECK_EQ(dimension, 4); - CHECK(dimension == data_.dimension_); - data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; - data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; - data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; - data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + if (dimension == 5) { + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<3>(data_format, '0')] = + MklDnnDims3D::Dim3d_D; + data_.map_[GetTensorDimIndex<3>(data_format, '1')] = + MklDnnDims3D::Dim3d_H; + data_.map_[GetTensorDimIndex<3>(data_format, '2')] = + MklDnnDims3D::Dim3d_W; + data_.map_[GetTensorDimIndex<3>(data_format, 'C')] = + MklDnnDims3D::Dim3d_C; + data_.map_[GetTensorDimIndex<3>(data_format, 'N')] = + MklDnnDims3D::Dim3d_N; + } else { + CHECK_EQ(dimension, 4); + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; + data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; + data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; + 
data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + } } inline void SetTfDimOrder(const size_t dimension, memory::format format) { @@ -1329,6 +1378,19 @@ memory::data_type MklDnnType<float>() { return memory::data_type::f32; } +/// Map TensorFlow's data format into MKL-DNN 3D data format +/// @input: TensorFlow data format +/// @return: memory::format corresponding to TensorFlow data format; +/// Fails with an error if invalid data format. +inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { + if (format == FORMAT_NHWC) + return memory::format::ndhwc; + else if (format == FORMAT_NCHW) + return memory::format::ncdhw; + TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); + return memory::format::format_undef; +} + /// Map TensorFlow's data format into MKL-DNN data format /// /// @input: TensorFlow data format @@ -1340,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { else if (format == FORMAT_NCHW) return memory::format::nchw; TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); - // Return to get rid of compiler warning return memory::format::format_undef; } @@ -1350,9 +1411,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { /// @return: Tensorflow data format corresponding to memory::format /// Fails with an error if invalid data format. 
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
-  if (format == memory::format::nhwc)
+  if (format == memory::format::nhwc || format == memory::format::ndhwc)
     return FORMAT_NHWC;
-  else if (format == memory::format::nchw)
+  else if (format == memory::format::nchw || format == memory::format::ncdhw)
     return FORMAT_NCHW;
   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
@@ -1402,6 +1463,22 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
   return memory::dims({n, c, h, w});
 }
 
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+                                               TensorFormat format) {
+  // Check validity of format.
+  CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+           memory::format::format_undef);
+
+  int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+  int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+  int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+  int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+  int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+  // MKL-DNN requires dimensions in NCDHW format.
+  return memory::dims({n, c, d, h, w});
+}
+
 /// Overloaded version of function above. Input parameters are
 /// self-explanatory.
 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1591,8 @@ class MklDnnData {
   /// Operations memory descriptor
   memory::desc* op_md_;
+  // Flag to indicate if data is 3D or not.
+  bool bIs3D;
   /// Operations temp buffer
   void* allocated_buffer_;
   /// CPU engine on which operation will be executed
@@ -1540,6 +1619,10 @@ class MklDnnData {
         static_cast<const void*>(tensor->flat<T>().data()));
   }
 
+  void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+  bool GetIs3D() { return bIs3D; }
+
   /// Set user memory primitive using specified dimensions, memory format and
   /// data_buffer. Function automatically uses element data type by using
   /// input type T used for creating call object.
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index a5f7ecf0d1..f331973f5c 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() { return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' "; } +string GetConvnetDataFormat2D3DAttrString() { + return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' "; +} + string GetConvnetFilterFormatAttrString() { return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' "; } diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 918835e1fb..b0c349dd90 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString(); // Return the string that specifies the filter format for convnet operations. string GetConvnetFilterFormatAttrString(); string GetConvnet3dFilterFormatAttrString(); +string GetConvnetDataFormat2D3DAttrString(); // Returns a tensor shape for the specified format and dimension sizes. // Works for both 2D and 3D operations. The output shapes are as follows: |