aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-07 19:05:18 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-07 19:05:18 +0000
commitb43102440489df9d0175c88e602dfa425b574a94 (patch)
tree9325c3401de7047451d4a59ad343cdf1c5a83679 /unsupported
parentf66f3393e3d567e5c8b138fbad69b316214a4ce9 (diff)
Don't make assumptions about NaN-propagation for pmin/pmax - it various across platforms.
Change test to only test for NaN-propagation for pfmin/pfmax.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h22
-rw-r--r--unsupported/test/cxx11_tensor_expr.cpp95
2 files changed, 79 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index bb0969f49..ef332dd19 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -395,16 +395,18 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_mod_op<Scalar>(rhs));
}
+ template <int NanPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMax(Scalar threshold) const {
- return cwiseMax(constant(threshold));
+ return cwiseMax<NanPropagation>(constant(threshold));
}
+ template <int NanPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
cwiseMin(Scalar threshold) const {
- return cwiseMin(constant(threshold));
+ return cwiseMin<NanPropagation>(constant(threshold));
}
template<typename NewType>
@@ -472,16 +474,16 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>());
}
- template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
+ template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
cwiseMax(const OtherDerived& other) const {
- return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>());
+ return binaryExpr(other.derived(), internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
}
- template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
+ template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
cwiseMin(const OtherDerived& other) const {
- return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>());
+ return binaryExpr(other.derived(), internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
}
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp
index b49663fe9..7fac3b4ed 100644
--- a/unsupported/test/cxx11_tensor_expr.cpp
+++ b/unsupported/test/cxx11_tensor_expr.cpp
@@ -303,40 +303,79 @@ template <typename Scalar>
void test_minmax_nan_propagation_templ() {
for (int size = 1; size < 17; ++size) {
const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN();
+ const Scalar kZero(0);
Tensor<Scalar, 1> vec_nan(size);
Tensor<Scalar, 1> vec_zero(size);
- Tensor<Scalar, 1> vec_res(size);
vec_nan.setConstant(kNan);
vec_zero.setZero();
- vec_res.setZero();
-
- // Test that we propagate NaNs in the tensor when applying the
- // cwiseMax(scalar) operator, which is used for the Relu operator.
- vec_res = vec_nan.cwiseMax(Scalar(0));
- for (int i = 0; i < size; ++i) {
- VERIFY((numext::isnan)(vec_res(i)));
- }
-
- // Test that NaNs do not propagate if we reverse the arguments.
- vec_res = vec_zero.cwiseMax(kNan);
- for (int i = 0; i < size; ++i) {
- VERIFY_IS_EQUAL(vec_res(i), Scalar(0));
- }
-
- // Test that we propagate NaNs in the tensor when applying the
- // cwiseMin(scalar) operator.
- vec_res.setZero();
- vec_res = vec_nan.cwiseMin(Scalar(0));
- for (int i = 0; i < size; ++i) {
- VERIFY((numext::isnan)(vec_res(i)));
- }
+ auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) {
+ for (int i = 0; i < size; ++i) {
+ VERIFY((numext::isnan)(v(i)));
+ }
+ };
- // Test that NaNs do not propagate if we reverse the arguments.
- vec_res = vec_zero.cwiseMin(kNan);
- for (int i = 0; i < size; ++i) {
- VERIFY_IS_EQUAL(vec_res(i), Scalar(0));
- }
+ auto verify_all_zero = [&](const Tensor<Scalar, 1>& v) {
+ for (int i = 0; i < size; ++i) {
+ VERIFY_IS_EQUAL(v(i), Scalar(0));
+ }
+ };
+
+ // Test NaN propagating max.
+ // max(nan, nan) = nan
+ // max(nan, 0) = nan
+ // max(0, nan) = nan
+ // max(0, 0) = 0
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan));
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan));
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero));
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero));
+ verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan));
+ verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero));
+
+ // Test number propagating max.
+ // max(nan, nan) = nan
+ // max(nan, 0) = 0
+ // max(0, nan) = 0
+ // max(0, 0) = 0
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan));
+ verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan));
+ verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero));
+ verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero));
+
+ // Test NaN propagating min.
+ // min(nan, nan) = nan
+ // min(nan, 0) = nan
+ // min(0, nan) = nan
+ // min(0, 0) = 0
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan));
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan));
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero));
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero));
+ verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan));
+ verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero));
+
+ // Test number propagating min.
+ // min(nan, nan) = nan
+ // min(nan, 0) = 0
+ // min(0, nan) = 0
+ // min(0, 0) = 0
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan));
+ verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan));
+ verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero));
+ verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero));
}
}