diff options
author | 2016-08-08 09:06:45 -0700 | |
---|---|---|
committer | 2016-08-08 09:06:45 -0700 | |
commit | 72096f3bd4e1db96bae816affffcc2544a5fd005 (patch) | |
tree | 584f6fae3fcecaf58306be3122f7597fa9ca03c8 | |
parent | 3e4a33d4bac400f857fc165c9b119901c1c7f5e5 (diff) | |
parent | 1031223c095c7685347b4930e81b390ee88c35e0 (diff) |
Merged in suiyuan2009/eigen/fix_tanh_inconsistent_for_tensorflow (pull request PR-215)
Fix_tanh_inconsistent_for_tensorflow
-rw-r--r-- | Eigen/src/Core/functors/UnaryFunctors.h | 117 |
1 files changed, 89 insertions, 28 deletions
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 04208c9fe..e2f3d869f 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -491,19 +491,62 @@ struct functor_traits<scalar_atan_op<Scalar> > }; }; - /** \internal * \brief Template functor to compute the tanh of a scalar * \sa class CwiseUnaryOp, ArrayBase::tanh() */ -template<typename Scalar> struct scalar_tanh_op { +template <typename Scalar> +struct scalar_tanh_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_op) - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::tanh(a); } + EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { + /** \internal \returns the hyperbolic tan of \a a (coeff-wise) + Doesn't do anything fancy, just a 13/6-degree rational interpolant + which + is accurate up to a couple of ulp in the range [-9, 9], outside of + which + the fl(tanh(x)) = +/-1. */ + + // Clamp the inputs to the range [-9, 9] since anything outside + // this range is +/-1.0f in single-precision. + const Scalar plus_9 = static_cast<Scalar>(9.0); + const Scalar minus_9 = static_cast<Scalar>(-9.0); + const Scalar x = numext::maxi(minus_9, numext::mini(plus_9, a)); + // Scalarhe monomial coefficients of the numerator polynomial (odd). + const Scalar alpha_1 = static_cast<Scalar>(4.89352455891786e-03); + const Scalar alpha_3 = static_cast<Scalar>(6.37261928875436e-04); + const Scalar alpha_5 = static_cast<Scalar>(1.48572235717979e-05); + const Scalar alpha_7 = static_cast<Scalar>(5.12229709037114e-08); + const Scalar alpha_9 = static_cast<Scalar>(-8.60467152213735e-11); + const Scalar alpha_11 = static_cast<Scalar>(2.00018790482477e-13); + const Scalar alpha_13 = static_cast<Scalar>(-2.76076847742355e-16); + // Scalarhe monomial coefficients of the denominator polynomial (even). + const Scalar beta_0 = static_cast<Scalar>(4.89352518554385e-03); + const Scalar beta_2 = static_cast<Scalar>(2.26843463243900e-03); + const Scalar beta_4 = static_cast<Scalar>(1.18534705686654e-04); + const Scalar beta_6 = static_cast<Scalar>(1.19825839466702e-06); + // Since the polynomials are odd/even, we need x^2. + const Scalar x2 = x * x; + // Evaluate the numerator polynomial p. + Scalar p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + // Evaluate the denominator polynomial p. + Scalar q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + // Divide the numerator by the denominator. + return p / q; + } template <typename Packet> EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& _x) const { /** \internal \returns the hyperbolic tan of \a a (coeff-wise) Doesn't do anything fancy, just a 13/6-degree rational interpolant which - is accurate up to a couple of ulp in the range [-9, 9], outside of which the + is accurate up to a couple of ulp in the range [-9, 9], outside of which + the fl(tanh(x)) = +/-1. */ // Clamp the inputs to the range [-9, 9] since anything outside @@ -511,7 +554,7 @@ template<typename Scalar> struct scalar_tanh_op { const Packet plus_9 = pset1<Packet>(9.0); const Packet minus_9 = pset1<Packet>(-9.0); const Packet x = pmax(minus_9, pmin(plus_9, _x)); - + // The monomial coefficients of the numerator polynomial (odd). const Packet alpha_1 = pset1<Packet>(4.89352455891786e-03); const Packet alpha_3 = pset1<Packet>(6.37261928875436e-04); @@ -520,17 +563,17 @@ template<typename Scalar> struct scalar_tanh_op { const Packet alpha_9 = pset1<Packet>(-8.60467152213735e-11); const Packet alpha_11 = pset1<Packet>(2.00018790482477e-13); const Packet alpha_13 = pset1<Packet>(-2.76076847742355e-16); - + // The monomial coefficients of the denominator polynomial (even). const Packet beta_0 = pset1<Packet>(4.89352518554385e-03); const Packet beta_2 = pset1<Packet>(2.26843463243900e-03); const Packet beta_4 = pset1<Packet>(1.18534705686654e-04); const Packet beta_6 = pset1<Packet>(1.19825839466702e-06); - + // Since the polynomials are odd/even, we need x^2. const Packet x2 = pmul(x, x); - - // Evaluate the numerator polynomial p. + + // Evaluate the numerator polynomial p. Packet p = pmadd(x2, alpha_13, alpha_11); p = pmadd(x2, p, alpha_9); p = pmadd(x2, p, alpha_7); @@ -538,38 +581,56 @@ template<typename Scalar> struct scalar_tanh_op { p = pmadd(x2, p, alpha_3); p = pmadd(x2, p, alpha_1); p = pmul(x, p); - + // Evaluate the denominator polynomial p. Packet q = pmadd(x2, beta_6, beta_4); q = pmadd(x2, q, beta_2); q = pmadd(x2, q, beta_0); - + // Divide the numerator by the denominator. return pdiv(p, q); } }; -template<typename Scalar> -struct functor_traits<scalar_tanh_op<Scalar> > -{ +template <> +struct scalar_tanh_op<std::complex<double> > { + EIGEN_DEVICE_FUNC inline const std::complex<double> operator()( + const std::complex<double>& a) const { + return numext::tanh(a); + } +}; +template <> +struct scalar_tanh_op<std::complex<float> > { + EIGEN_DEVICE_FUNC inline const std::complex<float> operator()( + const std::complex<float>& a) const { + return numext::tanh(a); + } +}; +template <typename Scalar> +struct functor_traits<scalar_tanh_op<Scalar> > { enum { PacketAccess = packet_traits<Scalar>::HasTanh, - Cost = - (PacketAccess - // The following numbers are based on the AVX implementation, + Cost = (PacketAccess && (!is_same<Scalar, std::complex<float> >::value) && + (!is_same<Scalar, std::complex<double> >::value) +// The following numbers are based on the AVX implementation, #ifdef EIGEN_VECTORIZE_FMA - // Haswell can issue 2 add/mul/madd per cycle. - // 9 pmadd, 2 pmul, 1 div, 2 other - ? (2 * NumTraits<Scalar>::AddCost + 6 * NumTraits<Scalar>::MulCost + - NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost) + // Haswell can issue 2 add/mul/madd per cycle. + // 9 pmadd, 2 pmul, 1 div, 2 other + ? (2 * NumTraits<Scalar>::AddCost + + 6 * NumTraits<Scalar>::MulCost + + NumTraits<Scalar>::template Div< + packet_traits<Scalar>::HasDiv>::Cost) #else - ? (11 * NumTraits<Scalar>::AddCost + - 11 * NumTraits<Scalar>::MulCost + - NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost) + ? (11 * NumTraits<Scalar>::AddCost + + 11 * NumTraits<Scalar>::MulCost + + NumTraits<Scalar>::template Div< + packet_traits<Scalar>::HasDiv>::Cost) #endif - // This number assumes a naive implementation of tanh - : (6 * NumTraits<Scalar>::AddCost + 3 * NumTraits<Scalar>::MulCost + - 2 * NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost + - functor_traits<scalar_exp_op<Scalar> >::Cost)) + // This number assumes a naive implementation of tanh + : (6 * NumTraits<Scalar>::AddCost + + 3 * NumTraits<Scalar>::MulCost + + 2 * NumTraits<Scalar>::template Div< + packet_traits<Scalar>::HasDiv>::Cost + + functor_traits<scalar_exp_op<Scalar> >::Cost)) }; }; |