Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 174
1 files changed, 103 insertions, 71 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 
 #ifndef INTEL_MKL_ML_ONLY
 
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
 struct MklConvBwdInputParams {
   memory::dims diff_src_dims;
   memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
 };
 
 template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
  public:
-  explicit MklConv2DBwdInputPrimitive(
-      const MklConvBwdInputParams& convBwdInputDims) :
-          cpu_engine_(engine::cpu, 0) {
+  explicit MklConvBwdInputPrimitive(
+      const MklConvBwdInputParams& convBwdInputDims)
+      : cpu_engine_(engine::cpu, 0) {
     context_.bwd_input_stream.reset(new stream(stream::kind::eager));
 
     // create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
       Setup(convBwdInputDims);
     }
   }
-  ~MklConv2DBwdInputPrimitive() {}
+  ~MklConvBwdInputPrimitive() {}
 
   // Convolution backward filter (weights)
   //   diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
   }
 
  private:
-  // Primitive reuse context for Conv2D Bwd Input op
+  // Primitive reuse context for Conv Bwd Input op
   struct ConvBwdInputContext {
     // expected memory format for this primitive instance
    memory::format filter_fmt;
@@ -174,7 +174,6 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
     }
   };
 
-
   void Setup(const MklConvBwdInputParams& convBwdInputDims) {
     // create memory descriptors for convolution data w/ no specified format
     context_.diff_src_md.reset(new memory::desc(
@@ -235,38 +234,41 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
 };
 
 template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
  private:
-  MklConv2DBwdInputPrimitiveFactory() {}
-  ~MklConv2DBwdInputPrimitiveFactory() {}
+  MklConvBwdInputPrimitiveFactory() {}
+  ~MklConvBwdInputPrimitiveFactory() {}
 
  public:
-  static MklConv2DBwdInputPrimitive<T>* Get(
-      const MklConvBwdInputParams& convBwdInputDims) {
-    MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
-
-    // look into the pool for reusable primitive
-    conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
-        MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
-            convBwdInputDims));
-
-    if (conv2d_bwd_input == nullptr) {
-      conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
-          convBwdInputDims);
-      MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
-          convBwdInputDims, conv2d_bwd_input);
+  static MklConvBwdInputPrimitive<T>* Get(
+      const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
+    MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
+
+    if (do_not_cache) { /* Always allocate primitive */
+      conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+    } else {
+      // look into the pool for reusable primitive
+      conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+          MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+              convBwdInputDims));
+      if (conv_bwd_input == nullptr) {
+        conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+        MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+            convBwdInputDims, conv_bwd_input);
+      }
     }
-    return conv2d_bwd_input;
+
+    return conv_bwd_input;
   }
 
  private:
-  static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
-    static MklConv2DBwdInputPrimitiveFactory instance_;
+  static MklConvBwdInputPrimitiveFactory& GetInstance() {
+    static MklConvBwdInputPrimitiveFactory instance_;
     return instance_;
   }
 
   static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
-    string prefix = "conv2d_bwd_input";
+    string prefix = "conv_bwd_input";
     FactoryKeyCreator key_creator;
     key_creator.AddAsKey(prefix);
     key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -279,14 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
     return key_creator.GetKey();
   }
 
-  MklPrimitive* GetConv2dBwdInput(
-      const MklConvBwdInputParams& convBwdInputDims) {
+  MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
     string key = CreateKey(convBwdInputDims);
     return this->GetOp(key);
   }
 
-  void SetConv2dBwdInput(
-      const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+  void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+                       MklPrimitive* op) {
     string key = CreateKey(convBwdInputDims);
     this->SetOp(key, op);
   }
@@ -594,23 +595,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
   TensorFormat data_format;
 };
 
+#define REGISTER_MKL_CPU_KERNELS(T)                                    \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")              \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<T>("T")                  \
+                              .Label(mkl_op_registry::kMklOpLabel),    \
+                          MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
 #else
 
 template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
-    : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
  public:
-  explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
-      : MklConv2DBackpropCommonOp<Device, T>(context) {
-  }
+  explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+      : MklConvBackpropCommonOp<Device, T>(context) {}
 
-  ~MklConv2DCustomBackpropInputOp() {}
+  ~MklConvCustomBackpropInputOp() {}
 
   void Compute(OpKernelContext* context) {
     try {
       MklDnnData<T> filter(&cpu_engine);
       MklDnnData<T> diff_dst(&cpu_engine);
 
+      // This flag indicates whether this is a Conv2D or a Conv3D operation.
+      bool isConv2D = (this->strides_.size() == 4);
+
       // Input tensors
       const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
       const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -626,7 +638,7 @@ class MklConv2DCustomBackpropInputOp
                   diff_dst_mkl_shape);
 
       // Allow operator-specific generation of shapes.
-      // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+      // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
       // tensor containing shape of filter. So filter.shape() is not
       // a correct way to get filter shape. These operator-specific calls
       // allow this class to handle this case.
@@ -655,6 +667,7 @@ class MklConv2DCustomBackpropInputOp
         }
         return;
       }
+
       // By default, all dims are in MKL order. Only dims in TF order
       // are those with postfix tf_order.
       memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +686,18 @@ class MklConv2DCustomBackpropInputOp
 
       // Create Convolution forward descriptor since Convolution backward
       // API needs it. For that, we first need to create input, filter
       // and output memory descriptors.
-      auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+      auto tf_fmt = isConv2D
+                        ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+                        : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
 
       // If filter is in MKL layout, then simply grab filter layout;
       // otherwise, construct filter in TF layout.
       // For TF layout, filter is in HWIO format.
       auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
-          ? filter_mkl_shape.GetMklLayout()
-          : memory::desc(fwd_filter_dims, MklDnnType<T>(),
-                         memory::format::hwio);
+                               ? filter_mkl_shape.GetMklLayout()
+                               : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+                                              isConv2D ? memory::format::hwio
+                                                       : memory::format::dhwio);
 
       conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
       if (!context->status().ok()) return;
@@ -689,18 +705,25 @@ class MklConv2DCustomBackpropInputOp
                              ? diff_dst_mkl_shape.GetMklLayout()
                              : memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
 
+      for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
 
-      dilations[kDilationH] -= 1;
-      dilations[kDilationW] -= 1;
-
-      MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
-      conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+      MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
       MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
           diff_dst_dims, strides, dilations, padding_left, padding_right,
           TFPaddingToMklDnnPadding(this->padding_));
-      conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
-          convBwdInputDims);
-      auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+      // We don't cache these primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if the primitive
+      // descriptor includes potentially large buffers. MKL-DNN allocates
+      // buffers in the following cases:
+      //   1. Legacy CPU without AVX512/AVX2, or
+      //   2. 1x1 convolution with stride != 1
+      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+                          (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+                           IsConv1x1StrideNot1(fwd_filter_dims, strides));
+      conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+                                                               do_not_cache);
+      auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
 
       // allocate output tensor
       auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
@@ -723,7 +746,7 @@ class MklConv2DCustomBackpropInputOp
       // check if filter and diff_dst need reorder
       T* filter_data = nullptr;
       if (fwd_filter_md.data.format !=
-          conv2d_bwd_input->GetFilterMemoryFormat()) {
+          conv_bwd_input->GetFilterMemoryFormat()) {
         filter.SetUsrMem(fwd_filter_md, &filter_tensor);
         filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
         filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +756,7 @@ class MklConv2DCustomBackpropInputOp
       }
 
       T* diff_dst_data = nullptr;
-      if (diff_dst_md.data.format !=
-          conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+      if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
         diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
         diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
         diff_dst_data = static_cast<T*>(
@@ -745,7 +767,12 @@ class MklConv2DCustomBackpropInputOp
       }
 
       // execute convolution input bwd
-      conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+      conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+      // delete primitive since it is not cached.
+      if (do_not_cache) {
+        delete conv_bwd_input;
+      }
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
@@ -770,7 +797,7 @@ class MklConv2DCustomBackpropInputOp
     // of the Tensor and never an actual tensor. So it will never be in MKL
     // layout.
     CHECK(!input_mkl_shape.IsMklTensor())
-        << "Conv2DBackpropInput: input should not be in MKL Layout";
+        << "ConvBackpropInput: input should not be in MKL Layout";
   }
 
   // Get TensorFlow shape of input tensor.
@@ -778,10 +805,10 @@ class MklConv2DCustomBackpropInputOp
                              const Tensor& input_tensor) {
     TensorShape input_tf_shape;
     CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
-    CHECK_EQ(
-        TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
-            .ok(),
-        true);
+    // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64 for
+    // output_shape. MakeShape is able to handle both DT_INT32 and DT_INT64
+    // for input_tensor.
+    CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
     return input_tf_shape;
   }
 
@@ -792,7 +819,7 @@ class MklConv2DCustomBackpropInputOp
   }
 
   // Get the Tensorflow shape of Output (diff_src),
-  // which is same as shape of Conv2D 'input'.
+  // which is same as shape of Conv 'input'.
   TensorShape GetOutputTfShape(const TensorShape& input_shape,
                                const TensorShape& filter_shape,
                                const TensorShape& outbprop_shape) {
@@ -800,7 +827,7 @@ class MklConv2DCustomBackpropInputOp
   }
 
   // Get the Tensorflow shape of Output (diff_src),
-  // which is same as shape of Conv2D 'input'.
+  // which is same as shape of Conv 'input'.
   const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
                                     const memory::dims& fwd_filter_dims) {
     return fwd_input_dims;
@@ -839,17 +866,22 @@ class MklConv2DCustomBackpropInputOp
   }
 };
 
-#endif  // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T)                                 \
-  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")           \
-                              .Device(DEVICE_CPU)                   \
-                              .TypeConstraint<T>("T")               \
-                              .Label(mkl_op_registry::kMklOpLabel), \
-                          MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T)                                    \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")              \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<T>("T")                  \
+                              .Label(mkl_op_registry::kMklOpLabel),    \
+                          MklConvCustomBackpropInputOp<CPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2")            \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<T>("T")                  \
+                              .Label(mkl_op_registry::kMklOpLabel),    \
+                          MklConvCustomBackpropInputOp<CPUDevice, T>);
 
 TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
 #undef REGISTER_MKL_CPU_KERNELS
 
+#endif  // INTEL_MKL_ML_ONLY
+
 }  // namespace tensorflow
 
 #endif  // INTEL_MKL
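The heart of this change is the do_not_cache flag threaded through
MklConvBwdInputPrimitiveFactory<T>::Get(): when TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE
is enabled and the primitive descriptor would pin potentially large buffers, the
primitive is allocated fresh, used once, and deleted instead of being stored in
the process-wide cache. Below is a minimal, self-contained C++ sketch of that
caching pattern only; Params, Primitive, and PrimitiveFactory are illustrative
stand-ins, not the TensorFlow/MKL-DNN classes in the diff above.

#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for MklConvBwdInputParams.
struct Params {
  int filter_h = 1, filter_w = 1, stride = 1;
  // Analogous to CreateKey(): serialize every field that affects the
  // compiled primitive, so distinct shapes never collide in the pool.
  std::string Key() const {
    return "conv_bwd_input:" + std::to_string(filter_h) + "x" +
           std::to_string(filter_w) + ":s" + std::to_string(stride);
  }
};

// Hypothetical stand-in for MklConvBwdInputPrimitive<T>.
class Primitive {
 public:
  explicit Primitive(const Params& p) : params_(p) {}
  void Execute() { /* run the kernel prepared for params_ */ }

 private:
  Params params_;
};

class PrimitiveFactory {
 public:
  // When do_not_cache is true, the caller owns the returned primitive and
  // must delete it after Execute(), exactly as the kernel above does.
  static Primitive* Get(const Params& p, bool do_not_cache) {
    if (do_not_cache) return new Primitive(p);  // always allocate
    auto& pool = Instance().pool_;
    auto it = pool.find(p.Key());
    if (it == pool.end()) {
      // Cache miss: build the primitive once and reuse it afterwards.
      it = pool.emplace(p.Key(), std::make_unique<Primitive>(p)).first;
    }
    return it->second.get();
  }

 private:
  static PrimitiveFactory& Instance() {
    static PrimitiveFactory instance;  // process-wide singleton, like GetInstance()
    return instance;
  }
  std::unordered_map<std::string, std::unique_ptr<Primitive>> pool_;
};

int main() {
  Params p;
  p.filter_h = 1;
  p.filter_w = 1;
  p.stride = 2;
  // Mirrors the heuristic in the diff: a 1x1 filter with stride != 1 is one
  // of the "potentially large buffer" cases that should skip the cache.
  bool do_not_cache = (p.filter_h == 1 && p.filter_w == 1 && p.stride != 1);
  Primitive* prim = PrimitiveFactory::Get(p, do_not_cache);
  prim->Execute();
  if (do_not_cache) delete prim;  // uncached primitives are caller-owned
  return 0;
}

The same ownership rule explains the new "if (do_not_cache) delete conv_bwd_input;"
block at the end of Compute(): cached primitives belong to the factory singleton
and must never be deleted by the kernel, while uncached ones would leak unless
the kernel frees them after Execute().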