From 06a22ca5bd0e34689d4fcff0cdd6edfd1ee76d2d Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 17 Jul 2015 09:29:00 -0700 Subject: Added support for sigmoid function to the tensor module --- unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 6 +++++ .../Eigen/CXX11/src/Tensor/TensorFunctors.h | 30 ++++++++++++++++++++++ 2 files changed, 36 insertions(+) (limited to 'unsupported/Eigen') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 71f69b568..0e5e4b426 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -116,6 +116,12 @@ class TensorBase return unaryExpr(internal::scalar_tanh_op()); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + sigmoid() const { + return unaryExpr(internal::scalar_sigmoid_op()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> exp() const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index 14ffd5c93..1a23a9fa0 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -13,6 +13,36 @@ namespace Eigen { namespace internal { + +/** \internal + * \brief Template functor to compute the sigmoid of a scalar + * \sa class CwiseUnaryOp, ArrayBase::sigmoid() + */ +template +struct scalar_sigmoid_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { + const T one = T(1); + return one / (one + std::exp(-x)); + } + + template + inline Packet packetOp(const Packet& x) const { + const Packet one = pset1(1); + return pdiv(one, padd(one, pexp(pnegate(x)))); + } +}; + +template +struct functor_traits > { + enum { + Cost = NumTraits::AddCost * 2 + NumTraits::MulCost * 6, + PacketAccess = packet_traits::HasAdd && packet_traits::HasDiv && + packet_traits::HasNegate && packet_traits::HasExp + }; +}; + + // Standard reduction functors template struct SumReducer { -- cgit v1.2.3