author    | 2017-12-15 17:12:41 -0800
committer | 2017-12-15 17:16:29 -0800
commit    | d55f532867a3670d66460c5ee3b774519542adc1 (patch)
tree      | 7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/core/kernels/mkl_relu_op.cc
parent    | 32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff)
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/core/kernels/mkl_relu_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 505
1 file changed, 484 insertions(+), 21 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 86a77d769a..45bdd0ad5c 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -28,6 +28,19 @@ limitations under the License.
 #include "mkl_dnn.h"
 #include "mkl_dnn_types.h"
 
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::stream;
+using mkldnn::prop_kind;
+using mkldnn::algorithm;
+using mkldnn::relu_forward;
+using mkldnn::relu_backward;
+using mkldnn::eltwise_relu;
+using mkldnn::eltwise_elu;
+using mkldnn::eltwise_tanh;
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -45,6 +58,8 @@ struct MklReluHelpers {
   }
 };
 
+#ifndef INTEL_MKL_DNN
+
 template <typename Device, typename T>
 class MklReluOp : public OpKernel {
  public:
@@ -59,6 +74,7 @@ class MklReluOp : public OpKernel {
     GetMklShape(context, 0, &mkl_context.input_shape);
     void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
+
     if (!input_in_mkl_format && !input.dims()) {  // handle the case of a scalar
       const TensorShape& o_shape = input.shape();
       Tensor* out_tensor = nullptr;
@@ -164,6 +180,7 @@ class MklReluOp : public OpKernel {
   } MklReluOpContext;
 };
 
+
 template <typename Device, typename T>
 class MklReluGradOp : public OpKernel {
  public:
@@ -189,18 +206,18 @@ class MklReluGradOp : public OpKernel {
     const Tensor& a = MklGetInput(context, 1);
     void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
     void* mkl_buffer_convert = nullptr;
+
     dnnPrimitive_t cv_input_to_grad = nullptr;
 
-    // if input and grad are not in the same layout, do a conversion between
-    // them.
+    // if input and grad are not in the same layout,
+    // do a conversion between them.
     if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
       AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
                      &mkl_buffer_convert);
       CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
                E_SUCCESS);
       CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
-                                        mkl_buffer_convert),
-               E_SUCCESS);
+                                        mkl_buffer_convert), E_SUCCESS);
       relu_res[dnnResourceSrc] = mkl_buffer_convert;
       dnnDelete_F32(cv_input_to_grad);
     } else {
@@ -246,7 +263,6 @@ class MklReluGradOp : public OpKernel {
 };
 
 template <typename Device, typename T>
-
 void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
   MklReluGradOpContext mkl_context;
   const Tensor& g = MklGetInput(context, 0);
@@ -264,20 +280,21 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
       !MklReluHelpers::ValidateSameSize(context, g, a))
     return;
   Tensor* output = nullptr;
-  if (!input_is_mkl && !grad_is_mkl &&
-      !a.dims()) {  // handle the case of a scalar
-    // Allocate space for g and
+
+  if (!input_is_mkl && !grad_is_mkl && !a.dims()) {
+    // handle the scalar case
     const TensorShape& g_shape = g.shape();
     mkl_context.output_shape.SetMklTensor(false);
     AllocateOutputSetMklShape(context, 0, &output, g_shape,
                               mkl_context.output_shape);
+
    void* out_o = static_cast<void*>(output->flat<T>().data());
     (static_cast<T*>(out_o))[0] =
         (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
     return;
   }
 
-  // Generate size, stride for input if input/grad is in MKL format.
+  // generate size, stride for input if input/grad is in mkl format.
   if (grad_is_mkl || input_is_mkl) {
     const MklShape* tmp_mkl_shape =
         (grad_is_mkl) ? &mkl_context.grad_shape : &mkl_context.input_shape;
@@ -308,21 +325,20 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
   float negative_slope = 0.0;
   CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
                                      mkl_context.lt_grad, mkl_context.lt_grad,
-                                     negative_slope),
-           E_SUCCESS);
+                                     negative_slope), E_SUCCESS);
   Tensor mkl_tmp_input_buf_tensor;
   mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);
 
   if (input_is_mkl ||
-      grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/
+      grad_is_mkl) { /*if grad or input are mkl leave it in mkl*/
     TensorShape tf_shape;
     mkl_context.output_shape.SetMklTensor(true);
     mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_bwd,
                                           dnnResourceDiffSrc);
     mkl_context.output_shape.SetTfLayout(
         mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
-    // If input_is_mkl or grad_is_mkl, then we copy strides and sizes from Mkl
-    // shape of one that is in MKL layout.
+    // if input_is_mkl or grad_is_mkl, then we copy strides and sizes from mkl
+    // shape of one that is in mkl layout.
     if (grad_is_mkl == true) {
       mkl_context.output_shape.SetTfDimOrder(
           mkl_context.in_dims, mkl_context.grad_shape.GetTfToMklDimMap());
@@ -332,11 +348,9 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
     }
 
     tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
-                        mkl_context.output_shape.GetMklLayout())) /
-                    sizeof(T));
+                        mkl_context.output_shape.GetMklLayout())) / sizeof(T));
     AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                               mkl_context.output_shape);
-
   } else {
     const TensorShape& o_shape = g.shape();
     mkl_context.output_shape.SetMklTensor(false);
@@ -347,13 +361,430 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
   mkl_context.relu_res[dnnResourceDiffSrc] =
       static_cast<void*>(output->flat<T>().data());
 
-  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res),
-           E_SUCCESS);
+  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd,
+                          mkl_context.relu_res),
+           E_SUCCESS);
   mkl_context.MklCleanup();
 }
 
-/* Register DNN kernels for supported operations and supported types - right now
- * it is only Relu and f32*/
+
+#else  // INTEL_MKL_DNN
+
+template <typename Device, typename T, algorithm alg_kind>
+class MklReluOpBase : public OpKernel {
+ public:
+  ~MklReluOpBase() {}
+
+  explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {
+  }
+
+  virtual void Compute_Scalar(OpKernelContext* context) = 0;
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      auto cpu_engine = engine(engine::cpu, 0);
+      const size_t src_index = 0;  // index of src input tensor
+      const size_t dst_index = 0;  // index of dst output tensor
+      const Tensor& src_tensor = MklGetInput(context, src_index);
+      MklDnnShape dnn_shape_src;
+      GetMklShape(context, src_index, &dnn_shape_src);
+
+      Tensor* dst_tensor = nullptr;
+      if (src_tensor.dims() == 0) {
+        Compute_Scalar(context);
+        return;
+      }
+
+      // Create relu primitive.
+      MklDnnData<T> src(&cpu_engine);
+      MklDnnData<T> dst(&cpu_engine);
+
+      // Set DNN primitive - src
+      memory::desc src_md({}, memory::data_undef, memory::format_undef);
+      if (dnn_shape_src.IsMklTensor()) {
+        src_md = dnn_shape_src.GetMklLayout();
+      } else {
+        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+        auto src_strides = CalculateTFStrides(src_dims);
+        // Create blocked memory descriptor
+        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
+      }
+      src.SetUsrMem(src_md, &src_tensor);
+
+      T alpha = 0, beta = 0;
+      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
+      // Operator memory descriptor is same as user memory descriptor.
+      auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
+                                              alg_kind, src.GetUsrMemDesc(),
+                                              alpha, beta);
+      relu_fwd_pd.reset(new relu_forward::primitive_desc(relu_fwd_desc,
+                                                         cpu_engine));
+
+      // allocate dst tensor
+      MklDnnShape dnn_shape_dst;
+      TensorShape tf_shape_dst;
+      if (dnn_shape_src.IsMklTensor()) {
+        dnn_shape_dst.SetMklTensor(true);
+        auto dst_pd = relu_fwd_pd->dst_primitive_desc();
+        dnn_shape_dst.SetMklLayout(&dst_pd);
+        dnn_shape_dst.SetElemType(MklDnnType<T>());
+        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
+                                  dnn_shape_src.GetSizesAsMklDnnDims(),
+                                  dnn_shape_src.GetTfDataFormat());
+        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
+      } else {
+        dnn_shape_dst.SetMklTensor(false);
+        tf_shape_dst = src_tensor.shape();
+      }
+      AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
+                                dnn_shape_dst);
+
+      // Destination memory descriptor is same as source memory descriptor.
+      auto dst_md = src_md;
+      dst.SetUsrMem(dst_md, dst_tensor);
+
+      // execute net
+      std::vector<primitive> net;
+      auto relu_fwd = relu_forward(*relu_fwd_pd, src.GetOpMem(),
+                                   dst.GetOpMem());
+      net.push_back(relu_fwd);
+      stream(stream::kind::eager).submit(net).wait();
+    } catch (mkldnn::error& e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
+    }
+  }
+};
+
+
+template <typename Device, typename T, algorithm alg_kind>
+class MklReluGradOpBase : public OpKernel {
+ public:
+  ~MklReluGradOpBase() {}
+
+  explicit MklReluGradOpBase(OpKernelConstruction* context) :
+      OpKernel(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) = 0;
+
+  void Compute(OpKernelContext* context) {
+    try {
+      auto cpu_engine = engine(engine::cpu, 0);
+      MklDnnData<T> src(&cpu_engine);
+      MklDnnData<T> diff_dst(&cpu_engine);
+      MklDnnData<T> diff_src(&cpu_engine);
+
+      const size_t diff_dst_index = 0;  // index of diff_dst input tensor
+      const size_t src_index = 1;       // index of src input tensor
+      const size_t diff_src_index = 0;  // index of diff_src output tensor
+
+      const Tensor& src_tensor = MklGetInput(context, src_index);
+      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
+      Tensor* diff_src_tensor = nullptr;
+
+      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
+      GetMklShape(context, src_index, &dnn_shape_src);
+      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+
+      int src_dims_size = src_tensor.dims();
+      if (src_dims_size == 0) {
+        Compute_Scalar(context);
+        return;
+      }
+
+      // Set DNN primitives for src & diff_dst
+      memory::desc src_md({}, memory::data_undef, memory::format_undef);
+      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
+      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
+        if (dnn_shape_diff_dst.IsMklTensor()) {
+          diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+          src_md = diff_dst_md;
+        } else {
+          src_md = dnn_shape_src.GetMklLayout();
+          diff_dst_md = src_md;
+        }
+      } else {
+        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+        auto src_strides = CalculateTFStrides(src_dims);
+        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
+        diff_dst_md = src_md;
+      }
+      src.SetUsrMem(src_md, &src_tensor);
+      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+
+      T alpha = 0, beta = 0;
+      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
+      auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
+                                              alg_kind, src_md, alpha, beta);
+      relu_fwd_pd.reset(new relu_forward::primitive_desc(relu_fwd_desc,
+                                                         cpu_engine));
+      auto relu_bwd_desc = relu_backward::desc(alg_kind, diff_dst_md, src_md,
+                                               alpha, beta);
+      auto relu_bwd_pd = relu_backward::primitive_desc(relu_bwd_desc,
+                                                       cpu_engine,
+                                                       *relu_fwd_pd);
+
+      // allocate diff_src tensor
+      MklDnnShape dnn_shape_diff_src;
+      TensorShape tf_shape_diff_src;
+      if (dnn_shape_src.IsMklTensor()) {
+        dnn_shape_diff_src.SetMklTensor(true);
+        auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
+        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
+        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
+        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
+                                       dnn_shape_src.GetSizesAsMklDnnDims(),
+                                       dnn_shape_src.GetTfDataFormat());
+        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
+      } else {
+        dnn_shape_diff_src.SetMklTensor(false);
+        tf_shape_diff_src = src_tensor.shape();
+      }
+      AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+                                tf_shape_diff_src, dnn_shape_diff_src);
+
+      // diff_src memory descriptor is same as diff_dst memory descriptor.
+      auto diff_src_md = diff_dst_md;
+      diff_src.SetUsrMem(diff_src_md, diff_src_tensor);
+
+      PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
+    } catch (mkldnn::error& e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
+    }
+  }
+
+  void PrepareAndExecuteNet(
+      const relu_backward::primitive_desc& relu_prim_desc, MklDnnData<T>* src,
+      MklDnnData<T>* diff_src, MklDnnData<T>* diff_dst) {
+    std::vector<primitive> net;
+    net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
+                                diff_dst->GetOpMem(), diff_src->GetOpMem()));
+    stream(stream::kind::eager).submit(net).wait();
+  }
+};
+
+
+template <typename Device, typename T>
+class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
+ public:
+  ~MklReluOp() {}
+
+  explicit MklReluOp(OpKernelConstruction* context) :
+      MklReluOpBase<Device, T, eltwise_relu>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t src_index = 0;  // index of src input tensor
+    const size_t dst_index = 0;  // index of dst output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    MklDnnShape dnn_shape_src;
+    GetMklShape(context, src_index, &dnn_shape_src);
+
+    Tensor* dst_tensor = nullptr;
+    void* user_i = static_cast<void*>(const_cast<T*>(
+        src_tensor.flat<T>().data()));
+    MklDnnShape dnn_shape_dst;
+    dnn_shape_dst.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+                              src_tensor.shape(), dnn_shape_dst);
+    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
+    (static_cast<T*>(out_o))[0] =
+        std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
+    return;
+  }
+};
+
+template <typename Device, typename T>
+class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
+ public:
+  ~MklReluGradOp() {}
+
+  explicit MklReluGradOp(OpKernelConstruction* context) :
+      MklReluGradOpBase<Device, T, eltwise_relu>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
+    const size_t src_index = 1;       // index of src input tensor
+    const size_t diff_src_index = 0;  // index of diff_src output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
+    Tensor* diff_src_tensor = nullptr;
+
+    MklDnnShape dnn_shape_diff_dst;
+    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+
+    int src_dims_size = src_tensor.dims();
+    MklDnnShape dnn_shape_diff_src;
+    dnn_shape_diff_src.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+                              diff_dst_tensor.shape(), dnn_shape_diff_src);
+    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
+    void* user_i =
+        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
+    void* user_g =
+        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
+    (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] *
+                                  ((static_cast<T*>(user_i))[0] > 0);
+    return;
+  }
+};
+
+template <typename Device, typename T>
+class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
+ public:
+  ~MklEluOp() {}
+
+  explicit MklEluOp(OpKernelConstruction* context) :
+      MklReluOpBase<Device, T, eltwise_elu>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t src_index = 0;  // index of src input tensor
+    const size_t dst_index = 0;  // index of dst output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    MklDnnShape dnn_shape_src;
+    GetMklShape(context, src_index, &dnn_shape_src);
+
+    Tensor* dst_tensor = nullptr;
+    void* user_i = static_cast<void*>(const_cast<T*>(
+        src_tensor.flat<T>().data()));
+    MklDnnShape dnn_shape_dst;
+    dnn_shape_dst.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+                              src_tensor.shape(), dnn_shape_dst);
+    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
+    // return exp(feature) - 1 if feature < 0; feature otherwise
+    T feature = (static_cast<T*>(user_i))[0];
+    if (feature < 0)
+      (static_cast<T*>(out_o))[0] = std::exp(feature) - 1;
+    else
+      (static_cast<T*>(out_o))[0] = feature;
+    return;
+  }
+};
+
+template <typename Device, typename T>
+class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
+ public:
+  ~MklEluGradOp() {}
+
+  explicit MklEluGradOp(OpKernelConstruction* context) :
+      MklReluGradOpBase<Device, T, eltwise_elu>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
+    const size_t src_index = 1;       // index of src input tensor
+    const size_t diff_src_index = 0;  // index of diff_src output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
+    Tensor* diff_src_tensor = nullptr;
+
+    MklDnnShape dnn_shape_diff_dst;
+    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+
+    int src_dims_size = src_tensor.dims();
+    MklDnnShape dnn_shape_diff_src;
+    dnn_shape_diff_src.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+                              diff_dst_tensor.shape(), dnn_shape_diff_src);
+    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
+    void* user_i =
+        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
+    void* user_g =
+        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
+    // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise
+    T feature = (static_cast<T*>(user_i))[0];
+    if (feature > 0) {
+      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
+    } else {
+      T elu = std::exp(feature) - 1;
+      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] * (elu + 1);
+    }
+  }
+};
+
+template <typename Device, typename T>
+class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
+ public:
+  ~MklTanhOp() {}
+
+  explicit MklTanhOp(OpKernelConstruction* context) :
+      MklReluOpBase<Device, T, eltwise_tanh>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t src_index = 0;  // index of src input tensor
+    const size_t dst_index = 0;  // index of dst output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    MklDnnShape dnn_shape_src;
+    GetMklShape(context, src_index, &dnn_shape_src);
+
+    Tensor* dst_tensor = nullptr;
+    void* user_i = static_cast<void*>(const_cast<T*>(
+        src_tensor.flat<T>().data()));
+    MklDnnShape dnn_shape_dst;
+    dnn_shape_dst.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+                              src_tensor.shape(), dnn_shape_dst);
+    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
+    // tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
+    T feature = (static_cast<T*>(user_i))[0];
+    T e1 = std::exp(feature);
+    T e2 = std::exp(-feature);
+    (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
+    return;
+  }
+};
+
+template <typename Device, typename T>
+class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
+ public:
+  ~MklTanhGradOp() {}
+
+  explicit MklTanhGradOp(OpKernelConstruction* context) :
+      MklReluGradOpBase<Device, T, eltwise_tanh>(context) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
+    const size_t src_index = 1;       // index of src input tensor
+    const size_t diff_src_index = 0;  // index of diff_src output tensor
+    const Tensor& src_tensor = MklGetInput(context, src_index);
+    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
+    Tensor* diff_src_tensor = nullptr;
+
+    MklDnnShape dnn_shape_diff_dst;
+    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+
+    int src_dims_size = src_tensor.dims();
+    MklDnnShape dnn_shape_diff_src;
+    dnn_shape_diff_src.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+                              diff_dst_tensor.shape(), dnn_shape_diff_src);
+    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
+    void* user_i =
+        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
+    // gradient of tanh(x) = 1 - tanh(x)^2
+    T feature = (static_cast<T*>(user_i))[0];
+    T e1 = std::exp(feature);
+    T e2 = std::exp(-feature);
+    T tanh = (e1 - e2) / (e1 + e2);
+    void* user_g =
+        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
+    (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] *
+                                  (1 - tanh * tanh);
+  }
+};
+
+#endif
+
+// register dnn kernels for supported operations and supported types
 #define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)             \
   REGISTER_KERNEL_BUILDER(Name("_MklRelu")                          \
                               .Device(DEVICE_CPU)                   \
@@ -367,6 +798,38 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
                           MklReluGradOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
 
+#ifdef INTEL_MKL_DNN
+
+// register dnn kernels for supported operations and supported types
+#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)              \
+  REGISTER_KERNEL_BUILDER(Name("_MklElu")                           \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklEluOp<CPUDevice, type>);               \
+  REGISTER_KERNEL_BUILDER(Name("_MklEluGrad")                       \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklEluGradOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
+
+#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)             \
+  REGISTER_KERNEL_BUILDER(Name("_MklTanh")                          \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklTanhOp<CPUDevice, type>);              \
+  REGISTER_KERNEL_BUILDER(Name("_MklTanhGrad")                      \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklTanhGradOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
+
+#endif
+
 }  // namespace tensorflow
 
 #endif  // INTEL_MKL
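For readers who want to see the MKL-DNN calls above in isolation, the sketch below strips the new MklReluOpBase::Compute() path down to its essence: build an eltwise ReLU forward descriptor, bind plain CPU memory, and submit a one-primitive net on an eager stream. This is a minimal standalone sketch against the mkldnn 0.x API this patch targets, not part of the commit; the 1-D buffer, its values, and main() are illustrative only, and every mkldnn call used here also appears in the diff itself.

// Minimal sketch of the mkldnn 0.x eltwise-ReLU flow used by the new
// kernels above. Buffer shape and contents are hypothetical.
#include <vector>
#include "mkldnn.hpp"

using namespace mkldnn;

int main() {
  auto cpu_engine = engine(engine::cpu, 0);

  // A plain 1-D fp32 buffer standing in for a TF tensor in TF layout.
  std::vector<float> src_buf = {-2.f, -1.f, 0.f, 0.5f, 1.f, 2.f};
  std::vector<float> dst_buf(src_buf.size());
  memory::desc md({static_cast<int>(src_buf.size())},
                  memory::data_type::f32, memory::format::x);
  auto src_mem = memory({md, cpu_engine}, src_buf.data());
  auto dst_mem = memory({md, cpu_engine}, dst_buf.data());

  // Forward ReLU descriptor: alpha is the negative slope (0 here),
  // beta is unused -- the same alpha/beta the kernel passes.
  float alpha = 0.0f, beta = 0.0f;
  auto fwd_desc = relu_forward::desc(prop_kind::forward_training,
                                     eltwise_relu, md, alpha, beta);
  auto fwd_pd = relu_forward::primitive_desc(fwd_desc, cpu_engine);

  // Build a one-primitive net and run it eagerly, mirroring the
  // stream(stream::kind::eager).submit(net).wait() step in Compute().
  std::vector<primitive> net;
  net.push_back(relu_forward(fwd_pd, src_mem, dst_mem));
  stream(stream::kind::eager).submit(net).wait();
  return 0;  // dst_buf now holds {0, 0, 0, 0.5, 1, 2}
}

The backward kernels follow the same pattern with relu_backward; its primitive descriptor takes the forward primitive descriptor as a hint, which is why MklReluGradOpBase constructs relu_fwd_pd even though it never executes the forward primitive.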