aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-07-07 01:54:04 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-07-07 01:54:04 +0000
commit6964ae8d52d42d2821572fc8359e56c821289e00 (patch)
tree18eee0bf5a4d10d026420a540879d908e9bb9fc6
parentcb6315318316653479b184db447bc29040be8e6e (diff)
Change the sign operator in Eigen to return NaN for NaN arguments, not zero.
-rw-r--r--Eigen/src/Core/functors/UnaryFunctors.h19
-rw-r--r--Eigen/src/Core/util/ForwardDeclarations.h2
-rw-r--r--test/array_cwise.cpp2
3 files changed, 18 insertions, 5 deletions
diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h
index 735d28ac8..a9cbfd967 100644
--- a/Eigen/src/Core/functors/UnaryFunctors.h
+++ b/Eigen/src/Core/functors/UnaryFunctors.h
@@ -904,9 +904,9 @@ struct functor_traits<scalar_boolean_not_op<Scalar> > {
* \brief Template functor to compute the signum of a scalar
* \sa class CwiseUnaryOp, Cwise::sign()
*/
-template<typename Scalar,bool iscpx=(NumTraits<Scalar>::IsComplex!=0) > struct scalar_sign_op;
+template<typename Scalar,bool is_complex=(NumTraits<Scalar>::IsComplex!=0), bool is_integer=(NumTraits<Scalar>::IsInteger!=0) > struct scalar_sign_op;
template<typename Scalar>
-struct scalar_sign_op<Scalar,false> {
+struct scalar_sign_op<Scalar, false, true> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
@@ -916,8 +916,21 @@ struct scalar_sign_op<Scalar,false> {
//template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
};
+
template<typename Scalar>
-struct scalar_sign_op<Scalar,true> {
+struct scalar_sign_op<Scalar, false, false> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
+ EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
+ {
+ return (numext::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
+ }
+ //TODO
+ //template <typename Packet>
+ //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
+};
+
+template<typename Scalar, bool is_integer>
+struct scalar_sign_op<Scalar,true, is_integer> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h
index 031a8ec3c..208b96c9c 100644
--- a/Eigen/src/Core/util/ForwardDeclarations.h
+++ b/Eigen/src/Core/util/ForwardDeclarations.h
@@ -205,7 +205,7 @@ template<typename Scalar, typename NewType> struct scalar_cast_op;
template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op;
-template<typename Scalar,bool iscpx> struct scalar_sign_op;
+template<typename Scalar,bool is_complex, bool is_integer> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index b3fb59bc8..48c0935a4 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -309,7 +309,7 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.cube(), cube(m1));
VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval()));
VERIFY_IS_APPROX(m1.sign(), sign(m1));
-
+ VERIFY((m1.sqrt().sign().isNaN() == (Eigen::isnan)(sign(sqrt(m1)))).all());
// avoid NaNs with abs() so verification doesn't fail
m3 = m1.abs();