aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-24 17:49:20 -0800
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2021-02-25 18:21:21 +0000
commit5297b7162a0630e2e5b1459fa665c9f3b1eb532a (patch)
tree5cc09d97a66f3a21b65be9bb1b58a5ff1f7035ed
parentecb7b19dfa6c4bbf7a4068e114a1c86aa88908fe (diff)
Make it possible to specify NaN propagation strategy for maxCoeff/minCoeff reductions.
-rw-r--r--Eigen/src/Core/DenseBase.h5
-rw-r--r--Eigen/src/Core/Redux.h30
-rw-r--r--test/array_cwise.cpp14
3 files changed, 49 insertions, 0 deletions
diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h
index 767a8e274..c83a3fcc6 100644
--- a/Eigen/src/Core/DenseBase.h
+++ b/Eigen/src/Core/DenseBase.h
@@ -452,6 +452,11 @@ template<typename Derived> class DenseBase
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar minCoeff() const;
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar maxCoeff() const;
+ template<int NaNPropagation>
+ EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar minCoeff() const;
+ template<int NaNPropagation>
+ EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar maxCoeff() const;
+
template<typename IndexType> EIGEN_DEVICE_FUNC
typename internal::traits<Derived>::Scalar minCoeff(IndexType* row, IndexType* col) const;
template<typename IndexType> EIGEN_DEVICE_FUNC
diff --git a/Eigen/src/Core/Redux.h b/Eigen/src/Core/Redux.h
index 2eef5abc5..4e5affc43 100644
--- a/Eigen/src/Core/Redux.h
+++ b/Eigen/src/Core/Redux.h
@@ -429,6 +429,21 @@ DenseBase<Derived>::minCoeff() const
return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar>());
}
+/** \returns the minimum of all coefficients of \c *this.
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
+ * \warning the matrix must be not empty, otherwise an assertion is triggered.
+ */
+template<typename Derived>
+template<int NaNPropagation>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+DenseBase<Derived>::minCoeff() const
+{
+ return derived().redux(Eigen::internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
+}
+
/** \returns the maximum of all coefficients of \c *this.
* \warning the matrix must be not empty, otherwise an assertion is triggered.
* \warning the result is undefined if \c *this contains NaN.
@@ -440,6 +455,21 @@ DenseBase<Derived>::maxCoeff() const
return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar>());
}
+/** \returns the maximum of all coefficients of \c *this.
+ * In case \c *this contains NaN, NaNPropagation determines the behavior:
+ * NaNPropagation == PropagateFast : undefined
+ * NaNPropagation == PropagateNaN : result is NaN
+ * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
+ * \warning the matrix must be not empty, otherwise an assertion is triggered.
+ */
+template<typename Derived>
+template<int NaNPropagation>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename internal::traits<Derived>::Scalar
+DenseBase<Derived>::maxCoeff() const
+{
+ return derived().redux(Eigen::internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
+}
+
/** \returns the sum of all coefficients of \c *this
*
* If \c *this is empty, then the value 0 is returned.
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index 7f7e44f89..92abf6968 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -610,6 +610,20 @@ template<typename ArrayType> void min_max(const ArrayType& m)
VERIFY_IS_APPROX(ArrayType::Constant(rows,cols, maxM1), (m1.max)( maxM1));
VERIFY_IS_APPROX(m1, (m1.max)( minM1));
+
+ // min/max with various NaN propagation options.
+ if (m1.size() > 1 && !NumTraits<Scalar>::IsInteger) {
+ m1(0,0) = std::numeric_limits<Scalar>::quiet_NaN();
+ maxM1 = m1.template maxCoeff<PropagateNaN>();
+ minM1 = m1.template minCoeff<PropagateNaN>();
+ VERIFY((numext::isnan)(maxM1));
+ VERIFY((numext::isnan)(minM1));
+
+ maxM1 = m1.template maxCoeff<PropagateNumbers>();
+ minM1 = m1.template minCoeff<PropagateNumbers>();
+ VERIFY(!(numext::isnan)(maxM1));
+ VERIFY(!(numext::isnan)(minM1));
+ }
}
EIGEN_DECLARE_TEST(array_cwise)