/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"

using mkldnn::algorithm;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#endif

#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

template <typename T>
class MklEltwiseFwdParams {
 public:
  memory::dims src_dims;  // check if this is needed
  memory::desc src_md;
  algorithm alg_kind;
  T alpha;
  T beta;

  MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
                      algorithm alg_kind, T alpha, T beta)
      : src_dims(src_dims),
        src_md(src_md),
        alg_kind(alg_kind),
        alpha(alpha),
        beta(beta) {}
};

template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
 public:
  explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
      : cpu_engine_(engine::cpu, 0) {
    // store expected format
    context_.src_fmt =
        static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
    context_.fwd_stream.reset(new stream(stream::kind::eager));

    // create eltwise primitive
    if (context_.eltwise_fwd == nullptr) {
      Setup(fwdParams);
    }
  }

  ~MklEltwiseFwdPrimitive() {}

  // Eltwise forward execute
  //   src_data:  input data buffer of src
  //   dst_data:  output data buffer of dst
  void Execute(const T* src_data, T* dst_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    context_.fwd_stream->submit(context_.fwd_primitives);

    // after execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
    return context_.fwd_pd;
  }

  memory::format GetSrcMemoryFormat() { return context_.src_fmt; }

 private:
  // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
  struct EltwiseFwdContext {
    // expected memory format for this primitive instance
    mkldnn::memory::format src_fmt;

    // MKLDNN memory
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;

    // desc & primitive desc
    std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;

    // memory desc
    std::shared_ptr<memory::desc> src_md;
    std::shared_ptr<memory::desc> dst_md;

    // memory primitive desc
    std::shared_ptr<memory::primitive_desc> src_mpd;

    // Eltwise primitive
    std::shared_ptr<mkldnn::primitive> eltwise_fwd;

    std::shared_ptr<stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    EltwiseFwdContext()
        : src_fmt(memory::format::any),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          src_mpd(nullptr),
          eltwise_fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  // Eltwise forward primitive setup
  void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
    // create memory descriptors for eltwise data with specified format
    context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
    context_.src_mpd.reset(
        new memory::primitive_desc(*context_.src_md, cpu_engine_));

    // create an eltwise forward descriptor and primitive descriptor
    context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
        prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
        fwdParams.alpha, fwdParams.beta));
    context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // create memory primitive based on dummy data
    context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));

    // create eltwise primitive and add it to net
    context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
        *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
    context_.fwd_primitives.push_back(*context_.eltwise_fwd);
  }

  struct EltwiseFwdContext context_;
  engine cpu_engine_;
};

template <typename T>
class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklEltwiseFwdPrimitive<T>* Get(
      const MklEltwiseFwdParams<T>& fwdParams) {
    MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;

    auto src_fmt = static_cast<memory::format>(fwdParams.src_md.data.format);

    // Get an eltwise fwd primitive from the cached pool
    eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
        MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
            fwdParams, src_fmt));
    if (eltwise_forward == nullptr) {
      eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
      MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
          fwdParams, src_fmt, eltwise_forward);
    }
    return eltwise_forward;
  }

  static MklEltwiseFwdPrimitiveFactory& GetInstance() {
    static MklEltwiseFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklEltwiseFwdPrimitiveFactory() {}
  ~MklEltwiseFwdPrimitiveFactory() {}

  static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
                          memory::format src_fmt) {
    string prefix = "eltwise_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey(static_cast<float>(fwdParams.alpha));
    key_creator.AddAsKey(static_cast<float>(fwdParams.beta));
    key_creator.AddAsKey(static_cast<int>(src_fmt));
    return key_creator.GetKey();
  }

  MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
                              memory::format src_fmt) {
    string key = CreateKey(fwdParams, src_fmt);
    return this->GetOp(key);
  }

  void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
                     memory::format src_fmt, MklPrimitive* op) {
    string key = CreateKey(fwdParams, src_fmt);
    this->SetOp(key, op);
  }
};
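// A sketch of how the forward classes above are used (see
// MklReluOpBase::Compute below for the real call site), assuming float data:
//
//   MklEltwiseFwdParams<float> fwdParams(src_dims, src_md, eltwise_relu,
//                                        0.0f /* alpha */, 0.0f /* beta */);
//   MklEltwiseFwdPrimitive<float>* eltwise_fwd =
//       MklEltwiseFwdPrimitiveFactory<float>::Get(fwdParams);
//   eltwise_fwd->Execute(src_data, dst_data);
//
// Get() returns a cached primitive when one with the same dims, algorithm,
// alpha/beta and source memory format has already been created.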
template <typename T>
class MklEltwiseBwdParams {
 public:
  memory::dims src_dims;
  memory::desc common_md;
  algorithm alg_kind;
  T alpha;
  T beta;

  MklEltwiseBwdParams(const memory::dims& src_dims,
                      const memory::desc& common_md, algorithm alg_kind,
                      T alpha, T beta)
      : src_dims(src_dims),
        common_md(common_md),
        alg_kind(alg_kind),
        alpha(alpha),
        beta(beta) {}
};

template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
 public:
  explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.src_fmt =
        static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
    context_.diff_dst_fmt =
        static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
    context_.bwd_stream.reset(new stream(stream::kind::eager));

    // create eltwise primitive
    if (context_.eltwise_bwd == nullptr) {
      Setup(bwdParams);
    }
  }

  ~MklEltwiseBwdPrimitive() {}

  // Eltwise backward execute
  //   src_data:       input data buffer of src
  //   diff_dst_data:  input data buffer of diff_dst
  //   diff_src_data:  output data buffer of diff_src
  void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));
    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
    context_.bwd_stream->submit(context_.bwd_primitives);

    // after execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    context_.diff_src_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
    return context_.bwd_pd;
  }

  memory::format GetSrcMemoryFormat() { return context_.src_fmt; }

  memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }

 private:
  // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
  struct EltwiseBwdContext {
    // expected memory format for this primitive instance
    memory::format src_fmt;
    memory::format diff_dst_fmt;

    // MKLDNN memory
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> diff_dst_mem;
    std::shared_ptr<memory> diff_src_mem;

    // desc & primitive desc
    std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;

    // memory desc
    std::shared_ptr<memory::desc> src_md;
    std::shared_ptr<memory::desc> diff_dst_md;
    std::shared_ptr<memory::desc> common_md;

    // memory primitive desc
    std::shared_ptr<memory::primitive_desc> src_mpd;
    std::shared_ptr<memory::primitive_desc> diff_dst_mpd;

    // fwd primitive desc
    std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;

    // Eltwise primitive
    std::shared_ptr<mkldnn::primitive> eltwise_bwd;

    std::shared_ptr<stream> bwd_stream;
    std::vector<mkldnn::primitive> bwd_primitives;

    EltwiseBwdContext()
        : src_fmt(memory::format::any),
          diff_dst_fmt(memory::format::any),
          src_mem(nullptr),
          diff_dst_mem(nullptr),
          diff_src_mem(nullptr),
          src_md(nullptr),
          diff_dst_md(nullptr),
          common_md(nullptr),
          src_mpd(nullptr),
          diff_dst_mpd(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          eltwise_bwd(nullptr),
          bwd_stream(nullptr) {}
  };

  // Eltwise backward primitive setup
  void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
    // create memory descriptors for eltwise data w/ no specified format
    context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
    context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
    context_.src_mpd.reset(
        new memory::primitive_desc(*context_.src_md, cpu_engine_));
    context_.diff_dst_mpd.reset(
        new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));

    // create forward eltwise primitive
    context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
        prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
        bwdParams.alpha, bwdParams.beta));
    context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));
    context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
        bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
        bwdParams.alpha, bwdParams.beta));
    context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
        *context_.bwd_desc, cpu_engine_, *context_.fwd_pd));

    // create memory primitive based on dummy data
    context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
    context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
    context_.diff_src_mem.reset(new memory(
        context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));

    // create eltwise primitive and add it to net
    context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
        *context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
        *context_.diff_src_mem));
    context_.bwd_primitives.push_back(*context_.eltwise_bwd);
  }

  struct EltwiseBwdContext context_;
  engine cpu_engine_;
};

template <typename T>
class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 private:
  MklEltwiseBwdPrimitiveFactory() {}
  ~MklEltwiseBwdPrimitiveFactory() {}

 public:
  static MklEltwiseBwdPrimitive<T>* Get(
      const MklEltwiseBwdParams<T>& bwdParams) {
    MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;

    auto src_fmt =
        static_cast<memory::format>(bwdParams.common_md.data.format);
    auto diff_dst_fmt =
        static_cast<memory::format>(bwdParams.common_md.data.format);

    // try to find a suitable one in pool
    eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
        MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
            bwdParams, src_fmt, diff_dst_fmt));
    if (eltwise_backward == nullptr) {
      eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
      MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
          bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
    }
    return eltwise_backward;
  }

  static MklEltwiseBwdPrimitiveFactory& GetInstance() {
    static MklEltwiseBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
                          const memory::format& src_fmt,
                          const memory::format& diff_dst_fmt) {
    string prefix = "eltwise_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
    key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
    key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
    key_creator.AddAsKey(static_cast<int>(src_fmt));
    key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
    return key_creator.GetKey();
  }

  MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
                              const memory::format& src_fmt,
                              const memory::format& diff_dst_fmt) {
    string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
    return this->GetOp(key);
  }

  void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
                     const memory::format& src_fmt,
                     const memory::format& diff_dst_fmt, MklPrimitive* op) {
    string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
    this->SetOp(key, op);
  }
};

#endif

typedef Eigen::ThreadPoolDevice CPUDevice;

struct MklReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};

#ifdef INTEL_MKL_ML_ONLY

template <typename Device, typename T>
class MklReluOp : public OpKernel {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    MklReluOpContext mkl_context;
    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &mkl_context.input_shape);
    void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
    if (!input_in_mkl_format && !input.dims()) {
      // handle the case of a scalar
      const TensorShape& o_shape = input.shape();
      Tensor* out_tensor = nullptr;
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
                                mkl_context.output_shape);
      void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
      (static_cast<T*>(out_o))[0] =
          std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
      return;
    }
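    // Note: for an input carrying an MKL layout, the sizes and strides below
    // come straight from its MklShape; for a plain TensorFlow tensor, the
    // dimensions are copied in reverse order (matching MKL's innermost-first
    // convention) and dense strides are computed from them before the layout
    // is created in MklCreateInputLayouts.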
    // Generate size, stride for input if input is in MKL format.
    if (input_in_mkl_format) {
      mkl_context.in_dims = mkl_context.input_shape.GetDimension();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = mkl_context.input_shape.GetSizes()[i];
        mkl_context.in_strides[i] = mkl_context.input_shape.GetStrides()[i];
      }
    } else {
      mkl_context.in_dims = input.dims();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = input.dim_size((mkl_context.in_dims - 1) - i);
      }
      mkl_context.in_strides[0] = 1;
      for (int i = 1; i < mkl_context.in_dims; i++) {
        mkl_context.in_strides[i] =
            mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
      }
    }

    float negative_slope = 0.0;
    mkl_context.MklCreateInputLayouts(context);
    CHECK_EQ(dnnReLUCreateForward_F32(&mkl_context.prim_relu_fwd, NULL,
                                      mkl_context.lt_input, negative_slope),
             E_SUCCESS);

    Tensor* output = nullptr;
    if (input_in_mkl_format) {
      TensorShape tf_shape;
      mkl_context.output_shape.SetMklTensor(true);
      mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_fwd,
                                            dnnResourceDst);
      mkl_context.output_shape.SetTfLayout(
          mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
      tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                          mkl_context.output_shape.GetMklLayout())) /
                      sizeof(T));
      AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                                mkl_context.output_shape);
    } else {
      const TensorShape& o_shape = input.shape();
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, o_shape,
                                mkl_context.output_shape);
    }

    void* user_o =
        static_cast<void*>(const_cast<T*>(output->flat<T>().data()));

    mkl_context.relu_res[dnnResourceDst] = user_o;
    mkl_context.relu_res[dnnResourceSrc] = user_i;
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_fwd, mkl_context.relu_res),
             E_SUCCESS);
    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, output_shape;
    dnnPrimitive_t prim_relu_fwd = nullptr;
    void* relu_res[dnnResourceNumber];
    dnnLayout_t lt_input = nullptr;

    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        dnnLayoutDelete_F32(lt_input);
        free(in_sizes);
        free(in_strides);
      }
      dnnDelete_F32(prim_relu_fwd);
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }
    }
  } MklReluOpContext;
};

template <typename Device, typename T>
class MklReluGradOp : public OpKernel {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override;

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, grad_shape, output_shape;
    void* relu_res[dnnResourceNumber];
    dnnPrimitive_t prim_relu_bwd;
    dnnLayout_t lt_input, lt_grad;

    void MklPrepareReluGradInputs(OpKernelContext* context,
                                  Tensor* mkl_tmp_input_buf_tensor) {
      const Tensor& g = MklGetInput(context, 0);
      const Tensor& a = MklGetInput(context, 1);
      void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
      void* mkl_buffer_convert = nullptr;
      dnnPrimitive_t cv_input_to_grad = nullptr;
      // if input and grad are not in the same layout,
      // do a conversion between them.
      if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
                       &mkl_buffer_convert);
        CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
                 E_SUCCESS);
        CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
                                          mkl_buffer_convert),
                 E_SUCCESS);
        relu_res[dnnResourceSrc] = mkl_buffer_convert;
        dnnDelete_F32(cv_input_to_grad);
      } else {
        relu_res[dnnResourceSrc] = buf_input;
      }

      void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
      relu_res[dnnResourceDiffDst] = buf_grad;
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      if (!input_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }

      if (!grad_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_grad, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_grad = static_cast<dnnLayout_t>(grad_shape.GetCurLayout());
      }
    }

    void MklCleanup() {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      dnnDelete_F32(prim_relu_bwd);
      if (!input_is_mkl) {
        dnnLayoutDelete_F32(lt_input);
        free(in_sizes);
        free(in_strides);
      }
      if (!grad_is_mkl) {
        dnnLayoutDelete_F32(lt_grad);
      }
    }
  } MklReluGradOpContext;
};

template <typename Device, typename T>
void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
  MklReluGradOpContext mkl_context;
  const Tensor& g = MklGetInput(context, 0);
  const Tensor& a = MklGetInput(context, 1);

  void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
  void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));

  GetMklShape(context, 0, &mkl_context.grad_shape);
  GetMklShape(context, 1, &mkl_context.input_shape);

  bool grad_is_mkl = mkl_context.grad_shape.IsMklTensor();
  bool input_is_mkl = mkl_context.input_shape.IsMklTensor();
  if (!input_is_mkl && !grad_is_mkl &&
      !MklReluHelpers::ValidateSameSize(context, g, a))
    return;

  Tensor* output = nullptr;
  if (!input_is_mkl && !grad_is_mkl && !a.dims()) {
    // handle the scalar case
    const TensorShape& g_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, g_shape,
                              mkl_context.output_shape);
    void* out_o = static_cast<void*>(output->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }

  // generate size, stride for input if input/grad is in mkl format.
  if (grad_is_mkl || input_is_mkl) {
    const MklShape* tmp_mkl_shape =
        (grad_is_mkl) ? &mkl_context.grad_shape : &mkl_context.input_shape;

    mkl_context.in_dims = tmp_mkl_shape->GetDimension();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];
    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
      mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
    }
  } else {
    mkl_context.in_dims = g.dims();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];

    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = g.dim_size((mkl_context.in_dims - 1) - i);
    }
    mkl_context.in_strides[0] = 1;
    for (int i = 1; i < mkl_context.in_dims; i++) {
      mkl_context.in_strides[i] =
          mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
    }
  }

  mkl_context.MklCreateInputLayouts(context);
  float negative_slope = 0.0;
  CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
                                     mkl_context.lt_grad, mkl_context.lt_grad,
                                     negative_slope),
           E_SUCCESS);
  Tensor mkl_tmp_input_buf_tensor;
  mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);

  if (input_is_mkl || grad_is_mkl) {
    // if grad or input are in mkl layout, keep the output in mkl layout
    TensorShape tf_shape;
    mkl_context.output_shape.SetMklTensor(true);
    mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_bwd,
                                          dnnResourceDiffSrc);
    mkl_context.output_shape.SetTfLayout(
        mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
    // if input_is_mkl or grad_is_mkl, then we copy strides and sizes from
    // the mkl shape of the one that is in mkl layout.
    if (grad_is_mkl == true) {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.grad_shape.GetTfToMklDimMap());
    } else {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
    }

    tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                        mkl_context.output_shape.GetMklLayout())) /
                    sizeof(T));
    AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                              mkl_context.output_shape);
  } else {
    const TensorShape& o_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, o_shape,
                              mkl_context.output_shape);
  }

  mkl_context.relu_res[dnnResourceDiffSrc] =
      static_cast<void*>(output->flat<T>().data());

  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res),
           E_SUCCESS);
  mkl_context.MklCleanup();
}

#else  // INTEL_MKL_ML_ONLY
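// MKL-DNN code path.  MklReluOpBase and MklReluGradOpBase hold the shared
// forward / backward logic (layout handling, reorders, cached primitives);
// the concrete Relu / Elu / Tanh kernels further below only pick the
// algorithm kind and supply a Compute_Scalar fallback, which is used for
// 0-dimensional inputs that are not handed to MKL-DNN primitives.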
template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
 public:
  ~MklReluOpBase() {}

  explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      const size_t src_index = 0;  // index of src input tensor
      const size_t dst_index = 0;  // index of dst output tensor
      const Tensor& src_tensor = MklGetInput(context, src_index);
      MklDnnShape dnn_shape_src;
      GetMklShape(context, src_index, &dnn_shape_src);

      if (src_tensor.dims() == 0) {
        Compute_Scalar(context);
        return;
      }

      // Set DNN primitive - src
      MklDnnData<T> src(&cpu_engine);
      memory::dims src_dims;
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
      } else {
        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        // Create blocked memory descriptor
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
      }

      T alpha = 0, beta = 0;

      // get an eltwise fwd primitive from the primitive pool
      MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha,
                                       beta);
      MklEltwiseFwdPrimitive<T>* eltwise_fwd =
          MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);

      // prepare for execution
      const T* src_data = src_tensor.flat<T>().data();
      // check whether src needs to be reordered
      if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
        src.SetUsrMem(src_md, &src_tensor);
        auto src_target_pd = memory::primitive_desc(
            {{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
            cpu_engine);
        src.CheckReorderToOpMem(src_target_pd);
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
      }

      // allocate dst tensor, always set it as MKL-DNN layout
      std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
          eltwise_fwd->GetEltwiseFwdPd();
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
                                  dnn_shape_src.GetSizesAsMklDnnDims(),
                                  dnn_shape_src.GetTfDataFormat());
        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = src_tensor.shape();
      }

      Tensor* dst_tensor = nullptr;
      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {static_cast<int>(src_index)},
                                  static_cast<int>(dst_index), tf_shape_dst,
                                  &dst_tensor));
      AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);

      T* dst_data = dst_tensor->flat<T>().data();

      // execute eltwise
      eltwise_fwd->Execute(src_data, dst_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  engine cpu_engine = engine(engine::cpu, 0);
  std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};

template <typename Device, typename T, algorithm alg_kind>
class MklReluGradOpBase : public OpKernel {
 public:
  ~MklReluGradOpBase() {}

  explicit MklReluGradOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) {
    try {
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);

      const size_t diff_dst_index = 0;  // index of diff_dst input tensor
      const size_t src_index = 1;       // index of src input tensor
      const size_t diff_src_index = 0;  // index of diff_src output tensor

      const Tensor& src_tensor = MklGetInput(context, src_index);
      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
      Tensor* diff_src_tensor = nullptr;

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, src_index, &dnn_shape_src);
      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

      int src_dims_size = src_tensor.dims();
      if (src_dims_size == 0) {
        Compute_Scalar(context);
        return;
      }

      // get an eltwise bwd primitive from the primitive pool
      memory::dims src_dims = {};
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
      if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
        src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
        diff_dst_md = src_md;
      } else if (dnn_shape_src.IsMklTensor() &&
                 !dnn_shape_diff_dst.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
        memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
        auto src_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
        auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                                       src_tf_data_format);
        diff_dst_md =
            memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
      } else if (!dnn_shape_src.IsMklTensor() &&
                 dnn_shape_diff_dst.IsMklTensor()) {
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
        memory::format diff_dst_mkl_data_format =
            dnn_shape_diff_dst.GetTfDataFormat();
        auto diff_dst_tf_data_format =
            MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
        src_dims = (src_tensor.dims() == 4)
                       ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                   diff_dst_tf_data_format)
                       : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
                                                    diff_dst_tf_data_format);
        src_md =
            memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
      } else {
        src_md = dnn_shape_src.GetMklLayout();
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
        src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
      }

      T alpha = 0, beta = 0;

      // We tell MKL-DNN that both inputs are in the same format, so we set a
      // common memory descriptor for them: the MKL layout if either input is
      // in MKL format, otherwise the blocked TF descriptor computed above.
      memory::desc common_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
      } else {
        // Since both the inputs are in Tensorflow format, and have
        // same shape, we can get memory descriptor from any input.
        common_md = src_md;
      }

      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha,
                                       beta);
      MklEltwiseBwdPrimitive<T>* eltwise_bwd =
          MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);

      auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();

      // check whether src / diff_dst need to be reordered
      const T* src_data = src_tensor.flat<T>().data();
      if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(
            eltwise_bwd_pd.get()->diff_src_primitive_desc());
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
      }

      const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
      if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
        diff_dst.CheckReorderToOpMem(
            eltwise_bwd_pd.get()->diff_src_primitive_desc());
        diff_dst_data = const_cast<T*>(
            reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
      }

      // allocate diff_src tensor
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
        dnn_shape_diff_src.SetMklTensor(true);
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        if (dnn_shape_src.IsMklTensor()) {
          dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
                                         dnn_shape_src.GetSizesAsMklDnnDims(),
                                         dnn_shape_src.GetTfDataFormat());
        } else {
          dnn_shape_diff_src.SetTfLayout(
              dnn_shape_diff_dst.GetDimension(),
              dnn_shape_diff_dst.GetSizesAsMklDnnDims(),
              dnn_shape_diff_dst.GetTfDataFormat());
        }
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }

      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {diff_dst_index}, diff_src_index,
                                  tf_shape_diff_src,
                                  &diff_src_tensor));
      AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);

      T* diff_src_data = diff_src_tensor->flat<T>().data();

      // execute eltwise bwd
      eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  engine cpu_engine = engine(engine::cpu, 0);
  std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};

template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
    return;
  }
};

template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }
};

template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluOp() {}

  explicit MklEluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // elu(feature) = exp(feature) - 1 if feature < 0; feature otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature < 0)
      (static_cast<T*>(out_o))[0] = std::exp(feature) - 1;
    else
      (static_cast<T*>(out_o))[0] = feature;
    return;
  }
};

template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluGradOp() {}

  explicit MklEluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature > 0) {
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
    } else {
      T elu = std::exp(feature) - 1;
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] * (elu + 1);
    }
  }
};

template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhOp() {}

  explicit MklTanhOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
    return;
  }
};

template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhGradOp() {}

  explicit MklTanhGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    // gradient of tanh(x) = 1 - tanh(x)^2
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    T tanh = (e1 - e2) / (e1 + e2);
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * (1 - tanh * tanh);
  }
};

#endif  // INTEL_MKL_ML_ONLY
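// Registering these kernels for another element type would amount to invoking
// the macros below through the matching TF_CALL_* helper, e.g.
//   TF_CALL_double(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
// assuming MKL-DNN provides eltwise primitives for that type.  Only float is
// registered here.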
// register dnn kernels for supported operations and supported types
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);

#ifndef INTEL_MKL_ML_ONLY

// register dnn kernels for supported operations and supported types
#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)              \
  REGISTER_KERNEL_BUILDER(Name("_MklElu")                           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluOp<CPUDevice, type>);               \
  REGISTER_KERNEL_BUILDER(Name("_MklEluGrad")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklTanh")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklTanhGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);

#endif

}  // namespace tensorflow

#endif  // INTEL_MKL