diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-07 19:05:18 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-07 19:05:18 +0000 |
commit | b43102440489df9d0175c88e602dfa425b574a94 (patch) | |
tree | 9325c3401de7047451d4a59ad343cdf1c5a83679 /unsupported | |
parent | f66f3393e3d567e5c8b138fbad69b316214a4ce9 (diff) |
Don't make assumptions about NaN-propagation for pmin/pmax - it various across platforms.
Change test to only test for NaN-propagation for pfmin/pfmax.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 22 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_expr.cpp | 95 |
2 files changed, 79 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index bb0969f49..ef332dd19 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -395,16 +395,18 @@ class TensorBase<Derived, ReadOnlyAccessors> return unaryExpr(internal::scalar_mod_op<Scalar>(rhs)); } + template <int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > cwiseMax(Scalar threshold) const { - return cwiseMax(constant(threshold)); + return cwiseMax<NanPropagation>(constant(threshold)); } + template <int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > + EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> > cwiseMin(Scalar threshold) const { - return cwiseMin(constant(threshold)); + return cwiseMin<NanPropagation>(constant(threshold)); } template<typename NewType> @@ -472,16 +474,16 @@ class TensorBase<Derived, ReadOnlyAccessors> return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>()); } - template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived> + template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived> cwiseMax(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>()); + return binaryExpr(other.derived(), internal::scalar_max_op<Scalar,Scalar, NaNPropagation>()); } - template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived> + template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived> cwiseMin(const OtherDerived& other) const { - return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>()); + return binaryExpr(other.derived(), internal::scalar_min_op<Scalar,Scalar, NaNPropagation>()); } template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index b49663fe9..7fac3b4ed 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -303,40 +303,79 @@ template <typename Scalar> void test_minmax_nan_propagation_templ() { for (int size = 1; size < 17; ++size) { const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN(); + const Scalar kZero(0); Tensor<Scalar, 1> vec_nan(size); Tensor<Scalar, 1> vec_zero(size); - Tensor<Scalar, 1> vec_res(size); vec_nan.setConstant(kNan); vec_zero.setZero(); - vec_res.setZero(); - - // Test that we propagate NaNs in the tensor when applying the - // cwiseMax(scalar) operator, which is used for the Relu operator. - vec_res = vec_nan.cwiseMax(Scalar(0)); - for (int i = 0; i < size; ++i) { - VERIFY((numext::isnan)(vec_res(i))); - } - - // Test that NaNs do not propagate if we reverse the arguments. - vec_res = vec_zero.cwiseMax(kNan); - for (int i = 0; i < size; ++i) { - VERIFY_IS_EQUAL(vec_res(i), Scalar(0)); - } - - // Test that we propagate NaNs in the tensor when applying the - // cwiseMin(scalar) operator. - vec_res.setZero(); - vec_res = vec_nan.cwiseMin(Scalar(0)); - for (int i = 0; i < size; ++i) { - VERIFY((numext::isnan)(vec_res(i))); - } + auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) { + for (int i = 0; i < size; ++i) { + VERIFY((numext::isnan)(v(i))); + } + }; - // Test that NaNs do not propagate if we reverse the arguments. - vec_res = vec_zero.cwiseMin(kNan); - for (int i = 0; i < size; ++i) { - VERIFY_IS_EQUAL(vec_res(i), Scalar(0)); - } + auto verify_all_zero = [&](const Tensor<Scalar, 1>& v) { + for (int i = 0; i < size; ++i) { + VERIFY_IS_EQUAL(v(i), Scalar(0)); + } + }; + + // Test NaN propagating max. + // max(nan, nan) = nan + // max(nan, 0) = nan + // max(0, nan) = nan + // max(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan)); + verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan)); + verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero)); + verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero)); + + // Test number propagating max. + // max(nan, nan) = nan + // max(nan, 0) = 0 + // max(0, nan) = 0 + // max(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan)); + verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan)); + verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero)); + verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero)); + + // Test NaN propagating min. + // min(nan, nan) = nan + // min(nan, 0) = nan + // min(0, nan) = nan + // min(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan)); + verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan)); + verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero)); + verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero)); + + // Test number propagating min. + // min(nan, nan) = nan + // min(nan, 0) = 0 + // min(0, nan) = 0 + // min(0, 0) = 0 + verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan)); + verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan)); + verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero)); + verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero)); } } |