aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-08-08 09:06:45 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-08-08 09:06:45 -0700
commit72096f3bd4e1db96bae816affffcc2544a5fd005 (patch)
tree584f6fae3fcecaf58306be3122f7597fa9ca03c8
parent3e4a33d4bac400f857fc165c9b119901c1c7f5e5 (diff)
parent1031223c095c7685347b4930e81b390ee88c35e0 (diff)
Merged in suiyuan2009/eigen/fix_tanh_inconsistent_for_tensorflow (pull request PR-215)
Fix_tanh_inconsistent_for_tensorflow
-rw-r--r--Eigen/src/Core/functors/UnaryFunctors.h117
1 files changed, 89 insertions, 28 deletions
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 04208c9fe..e2f3d869f 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -491,19 +491,62 @@ struct functor_traits<scalar_atan_op<Scalar> >
};
};
-
/** \internal
* \brief Template functor to compute the tanh of a scalar
* \sa class CwiseUnaryOp, ArrayBase::tanh()
*/
-template<typename Scalar> struct scalar_tanh_op {
+template <typename Scalar>
+struct scalar_tanh_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_op)
- EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { return numext::tanh(a); }
+ EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const {
+ /** \internal \returns the hyperbolic tan of \a a (coeff-wise)
+ Doesn't do anything fancy, just a 13/6-degree rational interpolant
+ which
+ is accurate up to a couple of ulp in the range [-9, 9], outside of
+ which
+ the fl(tanh(x)) = +/-1. */
+
+ // Clamp the inputs to the range [-9, 9] since anything outside
+ // this range is +/-1.0f in single-precision.
+ const Scalar plus_9 = static_cast<Scalar>(9.0);
+ const Scalar minus_9 = static_cast<Scalar>(-9.0);
+ const Scalar x = numext::maxi(minus_9, numext::mini(plus_9, a));
+ // Scalarhe monomial coefficients of the numerator polynomial (odd).
+ const Scalar alpha_1 = static_cast<Scalar>(4.89352455891786e-03);
+ const Scalar alpha_3 = static_cast<Scalar>(6.37261928875436e-04);
+ const Scalar alpha_5 = static_cast<Scalar>(1.48572235717979e-05);
+ const Scalar alpha_7 = static_cast<Scalar>(5.12229709037114e-08);
+ const Scalar alpha_9 = static_cast<Scalar>(-8.60467152213735e-11);
+ const Scalar alpha_11 = static_cast<Scalar>(2.00018790482477e-13);
+ const Scalar alpha_13 = static_cast<Scalar>(-2.76076847742355e-16);
+ // Scalarhe monomial coefficients of the denominator polynomial (even).
+ const Scalar beta_0 = static_cast<Scalar>(4.89352518554385e-03);
+ const Scalar beta_2 = static_cast<Scalar>(2.26843463243900e-03);
+ const Scalar beta_4 = static_cast<Scalar>(1.18534705686654e-04);
+ const Scalar beta_6 = static_cast<Scalar>(1.19825839466702e-06);
+ // Since the polynomials are odd/even, we need x^2.
+ const Scalar x2 = x * x;
+ // Evaluate the numerator polynomial p.
+ Scalar p = x2 * alpha_13 + alpha_11;
+ p = x2 * p + alpha_9;
+ p = x2 * p + alpha_7;
+ p = x2 * p + alpha_5;
+ p = x2 * p + alpha_3;
+ p = x2 * p + alpha_1;
+ p = x * p;
+ // Evaluate the denominator polynomial p.
+ Scalar q = x2 * beta_6 + beta_4;
+ q = x2 * q + beta_2;
+ q = x2 * q + beta_0;
+ // Divide the numerator by the denominator.
+ return p / q;
+ }
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& _x) const {
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
- is accurate up to a couple of ulp in the range [-9, 9], outside of which the
+ is accurate up to a couple of ulp in the range [-9, 9], outside of which
+ the
fl(tanh(x)) = +/-1. */
// Clamp the inputs to the range [-9, 9] since anything outside
@@ -511,7 +554,7 @@ template<typename Scalar> struct scalar_tanh_op {
const Packet plus_9 = pset1<Packet>(9.0);
const Packet minus_9 = pset1<Packet>(-9.0);
const Packet x = pmax(minus_9, pmin(plus_9, _x));
-
+
// The monomial coefficients of the numerator polynomial (odd).
const Packet alpha_1 = pset1<Packet>(4.89352455891786e-03);
const Packet alpha_3 = pset1<Packet>(6.37261928875436e-04);
@@ -520,17 +563,17 @@ template<typename Scalar> struct scalar_tanh_op {
const Packet alpha_9 = pset1<Packet>(-8.60467152213735e-11);
const Packet alpha_11 = pset1<Packet>(2.00018790482477e-13);
const Packet alpha_13 = pset1<Packet>(-2.76076847742355e-16);
-
+
// The monomial coefficients of the denominator polynomial (even).
const Packet beta_0 = pset1<Packet>(4.89352518554385e-03);
const Packet beta_2 = pset1<Packet>(2.26843463243900e-03);
const Packet beta_4 = pset1<Packet>(1.18534705686654e-04);
const Packet beta_6 = pset1<Packet>(1.19825839466702e-06);
-
+
// Since the polynomials are odd/even, we need x^2.
const Packet x2 = pmul(x, x);
-
- // Evaluate the numerator polynomial p.
+
+ // Evaluate the numerator polynomial p.
Packet p = pmadd(x2, alpha_13, alpha_11);
p = pmadd(x2, p, alpha_9);
p = pmadd(x2, p, alpha_7);
@@ -538,38 +581,56 @@ template<typename Scalar> struct scalar_tanh_op {
p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
-
+
// Evaluate the denominator polynomial p.
Packet q = pmadd(x2, beta_6, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
-
+
// Divide the numerator by the denominator.
return pdiv(p, q);
}
};
-template<typename Scalar>
-struct functor_traits<scalar_tanh_op<Scalar> >
-{
+template <>
+struct scalar_tanh_op<std::complex<double> > {
+ EIGEN_DEVICE_FUNC inline const std::complex<double> operator()(
+ const std::complex<double>& a) const {
+ return numext::tanh(a);
+ }
+};
+template <>
+struct scalar_tanh_op<std::complex<float> > {
+ EIGEN_DEVICE_FUNC inline const std::complex<float> operator()(
+ const std::complex<float>& a) const {
+ return numext::tanh(a);
+ }
+};
+template <typename Scalar>
+struct functor_traits<scalar_tanh_op<Scalar> > {
enum {
PacketAccess = packet_traits<Scalar>::HasTanh,
- Cost =
- (PacketAccess
- // The following numbers are based on the AVX implementation,
+ Cost = (PacketAccess && (!is_same<Scalar, std::complex<float> >::value) &&
+ (!is_same<Scalar, std::complex<double> >::value)
+// The following numbers are based on the AVX implementation,
#ifdef EIGEN_VECTORIZE_FMA
- // Haswell can issue 2 add/mul/madd per cycle.
- // 9 pmadd, 2 pmul, 1 div, 2 other
- ? (2 * NumTraits<Scalar>::AddCost + 6 * NumTraits<Scalar>::MulCost +
- NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost)
+ // Haswell can issue 2 add/mul/madd per cycle.
+ // 9 pmadd, 2 pmul, 1 div, 2 other
+ ? (2 * NumTraits<Scalar>::AddCost +
+ 6 * NumTraits<Scalar>::MulCost +
+ NumTraits<Scalar>::template Div<
+ packet_traits<Scalar>::HasDiv>::Cost)
#else
- ? (11 * NumTraits<Scalar>::AddCost +
- 11 * NumTraits<Scalar>::MulCost +
- NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost)
+ ? (11 * NumTraits<Scalar>::AddCost +
+ 11 * NumTraits<Scalar>::MulCost +
+ NumTraits<Scalar>::template Div<
+ packet_traits<Scalar>::HasDiv>::Cost)
#endif
- // This number assumes a naive implementation of tanh
- : (6 * NumTraits<Scalar>::AddCost + 3 * NumTraits<Scalar>::MulCost +
- 2 * NumTraits<Scalar>::template Div<packet_traits<Scalar>::HasDiv>::Cost +
- functor_traits<scalar_exp_op<Scalar> >::Cost))
+ // This number assumes a naive implementation of tanh
+ : (6 * NumTraits<Scalar>::AddCost +
+ 3 * NumTraits<Scalar>::MulCost +
+ 2 * NumTraits<Scalar>::template Div<
+ packet_traits<Scalar>::HasDiv>::Cost +
+ functor_traits<scalar_exp_op<Scalar> >::Cost))
};
};