From c6953f799b01d36f4236b64f351cc1446e0abe17 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Tue, 13 Oct 2020 21:48:31 +0000 Subject: Add packet generic ops `predux_fmin`, `predux_fmin_nan`, `predux_fmax`, and `predux_fmax_nan` that implement reductions with `PropagateNaN`, and `PropagateNumbers` semantics. Add (slow) generic implementations for most reductions. --- unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 22 ++--- .../Eigen/CXX11/src/Tensor/TensorFunctors.h | 41 +++++----- unsupported/test/cxx11_tensor_expr.cpp | 94 +++++++++++++++------- 3 files changed, 101 insertions(+), 56 deletions(-) (limited to 'unsupported') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index ef332dd19..3a70d8517 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -682,28 +682,30 @@ class TensorBase return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::ProdReducer()); } - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp, const Dims, const Derived> + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp, const Dims, const Derived> maximum(const Dims& dims) const { - return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); + return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MaxReducer()); } - const TensorReductionOp, const DimensionList, const Derived> + template + const TensorReductionOp, const DimensionList, const Derived> maximum() const { DimensionList in_dims; - return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MaxReducer()); + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MaxReducer()); } - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - const TensorReductionOp, const Dims, const Derived> + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorReductionOp, const Dims, const Derived> minimum(const Dims& dims) const { - return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); + return TensorReductionOp, const Dims, const Derived>(derived(), dims, internal::MinReducer()); } - const TensorReductionOp, const DimensionList, const Derived> + template + const TensorReductionOp, const DimensionList, const Derived> minimum() const { DimensionList in_dims; - return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MinReducer()); + return TensorReductionOp, const DimensionList, const Derived>(derived(), in_dims, internal::MinReducer()); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index 2edc45f1a..fd8fa00fa 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -192,17 +192,19 @@ struct MinMaxBottomValue { }; -template struct MaxReducer +template struct MaxReducer { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { - if (t > *accum) { *accum = t; } + scalar_max_op op; + *accum = op(t, *accum); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const { - (*accum) = pmax(*accum, p); + scalar_max_op op; + (*accum) = op.packetOp(*accum, p); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return MinMaxBottomValue::IsInteger>::bottom_value(); + return MinMaxBottomValue::IsInteger>::bottom_value(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { @@ -217,32 +219,34 @@ template struct MaxReducer } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const { - return numext::maxi(saccum, predux_max(vaccum)); + scalar_max_op op; + return op(saccum, op.predux(vaccum)); } }; -template -struct reducer_traits, Device> { +template + struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, PacketAccess = PacketType::HasMax, IsStateful = false, - IsExactlyAssociative = true + IsExactlyAssociative = (NaNPropagation!=PropagateFast) }; }; - -template struct MinReducer +template struct MinReducer { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { - if (t < *accum) { *accum = t; } + scalar_min_op op; + *accum = op(t, *accum); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const { - (*accum) = pmin(*accum, p); + scalar_min_op op; + (*accum) = op.packetOp(*accum, p); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return MinMaxBottomValue::IsInteger>::bottom_value(); + return MinMaxBottomValue::IsInteger>::bottom_value(); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { @@ -257,21 +261,21 @@ template struct MinReducer } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const { - return numext::mini(saccum, predux_min(vaccum)); + scalar_min_op op; + return op(saccum, op.predux(vaccum)); } }; -template -struct reducer_traits, Device> { +template + struct reducer_traits, Device> { enum { Cost = NumTraits::AddCost, PacketAccess = PacketType::HasMin, IsStateful = false, - IsExactlyAssociative = true + IsExactlyAssociative = (NaNPropagation!=PropagateFast) }; }; - template struct ProdReducer { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { @@ -282,7 +286,6 @@ template struct ProdReducer EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const { (*accum) = pmul(*accum, p); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { internal::scalar_cast_op conv; return conv(1); diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index 7fac3b4ed..556d01d4d 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -302,12 +302,17 @@ static void test_select() template void test_minmax_nan_propagation_templ() { for (int size = 1; size < 17; ++size) { - const Scalar kNan = std::numeric_limits::quiet_NaN(); + std::cout << "size = " << size << std::endl; + const Scalar kNaN = std::numeric_limits::quiet_NaN(); + const Scalar kInf = std::numeric_limits::infinity(); const Scalar kZero(0); - Tensor vec_nan(size); + Tensor vec_all_nan(size); + Tensor vec_one_nan(size); Tensor vec_zero(size); - vec_nan.setConstant(kNan); + vec_all_nan.setConstant(kNaN); vec_zero.setZero(); + vec_one_nan.setZero(); + vec_one_nan(size/2) = kNaN; auto verify_all_nan = [&](const Tensor& v) { for (int i = 0; i < size; ++i) { @@ -326,12 +331,12 @@ void test_minmax_nan_propagation_templ() { // max(nan, 0) = nan // max(0, nan) = nan // max(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMax(kNan)); - verify_all_nan(vec_nan.template cwiseMax(vec_nan)); - verify_all_nan(vec_nan.template cwiseMax(kZero)); - verify_all_nan(vec_nan.template cwiseMax(vec_zero)); - verify_all_nan(vec_zero.template cwiseMax(kNan)); - verify_all_nan(vec_zero.template cwiseMax(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMax(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMax(kZero)); + verify_all_nan(vec_all_nan.template cwiseMax(vec_zero)); + verify_all_nan(vec_zero.template cwiseMax(kNaN)); + verify_all_nan(vec_zero.template cwiseMax(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMax(kZero)); verify_all_zero(vec_zero.template cwiseMax(vec_zero)); @@ -340,12 +345,12 @@ void test_minmax_nan_propagation_templ() { // max(nan, 0) = 0 // max(0, nan) = 0 // max(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMax(kNan)); - verify_all_nan(vec_nan.template cwiseMax(vec_nan)); - verify_all_zero(vec_nan.template cwiseMax(kZero)); - verify_all_zero(vec_nan.template cwiseMax(vec_zero)); - verify_all_zero(vec_zero.template cwiseMax(kNan)); - verify_all_zero(vec_zero.template cwiseMax(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMax(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMax(kZero)); + verify_all_zero(vec_all_nan.template cwiseMax(vec_zero)); + verify_all_zero(vec_zero.template cwiseMax(kNaN)); + verify_all_zero(vec_zero.template cwiseMax(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMax(kZero)); verify_all_zero(vec_zero.template cwiseMax(vec_zero)); @@ -354,12 +359,12 @@ void test_minmax_nan_propagation_templ() { // min(nan, 0) = nan // min(0, nan) = nan // min(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMin(kNan)); - verify_all_nan(vec_nan.template cwiseMin(vec_nan)); - verify_all_nan(vec_nan.template cwiseMin(kZero)); - verify_all_nan(vec_nan.template cwiseMin(vec_zero)); - verify_all_nan(vec_zero.template cwiseMin(kNan)); - verify_all_nan(vec_zero.template cwiseMin(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMin(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMin(kZero)); + verify_all_nan(vec_all_nan.template cwiseMin(vec_zero)); + verify_all_nan(vec_zero.template cwiseMin(kNaN)); + verify_all_nan(vec_zero.template cwiseMin(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMin(kZero)); verify_all_zero(vec_zero.template cwiseMin(vec_zero)); @@ -368,14 +373,49 @@ void test_minmax_nan_propagation_templ() { // min(nan, 0) = 0 // min(0, nan) = 0 // min(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMin(kNan)); - verify_all_nan(vec_nan.template cwiseMin(vec_nan)); - verify_all_zero(vec_nan.template cwiseMin(kZero)); - verify_all_zero(vec_nan.template cwiseMin(vec_zero)); - verify_all_zero(vec_zero.template cwiseMin(kNan)); - verify_all_zero(vec_zero.template cwiseMin(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMin(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMin(kZero)); + verify_all_zero(vec_all_nan.template cwiseMin(vec_zero)); + verify_all_zero(vec_zero.template cwiseMin(kNaN)); + verify_all_zero(vec_zero.template cwiseMin(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMin(kZero)); verify_all_zero(vec_zero.template cwiseMin(vec_zero)); + + // Test min and max reduction + Tensor val; + val = vec_zero.minimum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.maximum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum(); + VERIFY_IS_EQUAL(val(), kZero); + + // Test NaN propagation for tensor of all NaNs. + val = vec_all_nan.template minimum(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template minimum(); + VERIFY_IS_EQUAL(val(), kInf); + val = vec_all_nan.template maximum(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template maximum(); + VERIFY_IS_EQUAL(val(), -kInf); + + // Test NaN propagation for tensor with a single NaN. + val = vec_one_nan.template minimum(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template minimum(); + VERIFY_IS_EQUAL(val(), (size == 1 ? kInf : kZero)); + val = vec_one_nan.template maximum(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template maximum(); + VERIFY_IS_EQUAL(val(), (size == 1 ? -kInf : kZero)); } } -- cgit v1.2.3