aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-13 21:48:31 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-13 21:48:31 +0000
commitc6953f799b01d36f4236b64f351cc1446e0abe17 (patch)
tree9abcded97c6effc010d08787c5b43ef7bb043b54 /unsupported
parent807e51528d220c0efed870f0505dea81a5776085 (diff)
Add packet generic ops `predux_fmin`, `predux_fmin_nan`, `predux_fmax`, and `predux_fmax_nan` that implement reductions with `PropagateNaN`, and `PropagateNumbers` semantics. Add (slow) generic implementations for most reductions.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h22
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h41
-rw-r--r--unsupported/test/cxx11_tensor_expr.cpp94
3 files changed, 101 insertions, 56 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index ef332dd19..3a70d8517 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -682,28 +682,30 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorReductionOp<internal::ProdReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::ProdReducer<CoeffReturnType>());
}
- template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>
+ template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
maximum(const Dims& dims) const {
- return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
- const TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
+ template <int NanPropagation=PropagateFast>
+ const TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
maximum() const {
DimensionList<Index, NumDimensions> in_dims;
- return TensorReductionOp<internal::MaxReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MaxReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MaxReducer<CoeffReturnType,NanPropagation>());
}
- template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>
+ template <typename Dims,int NanPropagation=PropagateFast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>
minimum(const Dims& dims) const {
- return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const Dims, const Derived>(derived(), dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
- const TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>
+ template <int NanPropagation=PropagateFast>
+ const TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>
minimum() const {
DimensionList<Index, NumDimensions> in_dims;
- return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>());
+ return TensorReductionOp<internal::MinReducer<CoeffReturnType,NanPropagation>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType,NanPropagation>());
}
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 2edc45f1a..fd8fa00fa 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -192,17 +192,19 @@ struct MinMaxBottomValue<T, false, false> {
};
-template <typename T> struct MaxReducer
+template <typename T, int NaNPropagation=PropagateFast> struct MaxReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
- if (t > *accum) { *accum = t; }
+ scalar_max_op<T, T, NaNPropagation> op;
+ *accum = op(t, *accum);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
- (*accum) = pmax<Packet>(*accum, p);
+ scalar_max_op<T, T, NaNPropagation> op;
+ (*accum) = op.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return MinMaxBottomValue<T, true, Eigen::NumTraits<T>::IsInteger>::bottom_value();
+ return MinMaxBottomValue<T, /*IsMax=*/true, Eigen::NumTraits<T>::IsInteger>::bottom_value();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
@@ -217,32 +219,34 @@ template <typename T> struct MaxReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
- return numext::maxi(saccum, predux_max(vaccum));
+ scalar_max_op<T, T, NaNPropagation> op;
+ return op(saccum, op.predux(vaccum));
}
};
-template <typename T, typename Device>
-struct reducer_traits<MaxReducer<T>, Device> {
+template <typename T, typename Device, int NaNPropagation>
+ struct reducer_traits<MaxReducer<T, NaNPropagation>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = PacketType<T, Device>::HasMax,
IsStateful = false,
- IsExactlyAssociative = true
+ IsExactlyAssociative = (NaNPropagation!=PropagateFast)
};
};
-
-template <typename T> struct MinReducer
+template <typename T, int NaNPropagation=PropagateFast> struct MinReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
- if (t < *accum) { *accum = t; }
+ scalar_min_op<T, T, NaNPropagation> op;
+ *accum = op(t, *accum);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
- (*accum) = pmin<Packet>(*accum, p);
+ scalar_min_op<T, T, NaNPropagation> op;
+ (*accum) = op.packetOp(*accum, p);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return MinMaxBottomValue<T, false, Eigen::NumTraits<T>::IsInteger>::bottom_value();
+ return MinMaxBottomValue<T, /*IsMax=*/false, Eigen::NumTraits<T>::IsInteger>::bottom_value();
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
@@ -257,21 +261,21 @@ template <typename T> struct MinReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
- return numext::mini(saccum, predux_min(vaccum));
+ scalar_min_op<T, T, NaNPropagation> op;
+ return op(saccum, op.predux(vaccum));
}
};
-template <typename T, typename Device>
-struct reducer_traits<MinReducer<T>, Device> {
+template <typename T, typename Device, int NaNPropagation>
+ struct reducer_traits<MinReducer<T, NaNPropagation>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = PacketType<T, Device>::HasMin,
IsStateful = false,
- IsExactlyAssociative = true
+ IsExactlyAssociative = (NaNPropagation!=PropagateFast)
};
};
-
template <typename T> struct ProdReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
@@ -282,7 +286,6 @@ template <typename T> struct ProdReducer
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
(*accum) = pmul<Packet>(*accum, p);
}
-
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
internal::scalar_cast_op<int, T> conv;
return conv(1);
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp
index 7fac3b4ed..556d01d4d 100644
--- a/unsupported/test/cxx11_tensor_expr.cpp
+++ b/unsupported/test/cxx11_tensor_expr.cpp
@@ -302,12 +302,17 @@ static void test_select()
template <typename Scalar>
void test_minmax_nan_propagation_templ() {
for (int size = 1; size < 17; ++size) {
- const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN();
+ std::cout << "size = " << size << std::endl;
+ const Scalar kNaN = std::numeric_limits<Scalar>::quiet_NaN();
+ const Scalar kInf = std::numeric_limits<Scalar>::infinity();
const Scalar kZero(0);
- Tensor<Scalar, 1> vec_nan(size);
+ Tensor<Scalar, 1> vec_all_nan(size);
+ Tensor<Scalar, 1> vec_one_nan(size);
Tensor<Scalar, 1> vec_zero(size);
- vec_nan.setConstant(kNan);
+ vec_all_nan.setConstant(kNaN);
vec_zero.setZero();
+ vec_one_nan.setZero();
+ vec_one_nan(size/2) = kNaN;
auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) {
for (int i = 0; i < size; ++i) {
@@ -326,12 +331,12 @@ void test_minmax_nan_propagation_templ() {
// max(nan, 0) = nan
// max(0, nan) = nan
// max(0, 0) = 0
- verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan));
- verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan));
- verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero));
- verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero));
- verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan));
- verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kNaN));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_all_nan));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kZero));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_zero));
+ verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNaN));
+ verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero));
@@ -340,12 +345,12 @@ void test_minmax_nan_propagation_templ() {
// max(nan, 0) = 0
// max(0, nan) = 0
// max(0, 0) = 0
- verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan));
- verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan));
- verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero));
- verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero));
- verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan));
- verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(kNaN));
+ verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_all_nan));
+ verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(kZero));
+ verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_zero));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNaN));
+ verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero));
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero));
@@ -354,12 +359,12 @@ void test_minmax_nan_propagation_templ() {
// min(nan, 0) = nan
// min(0, nan) = nan
// min(0, 0) = 0
- verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan));
- verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan));
- verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero));
- verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero));
- verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan));
- verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kNaN));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_all_nan));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kZero));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_zero));
+ verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNaN));
+ verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero));
@@ -368,14 +373,49 @@ void test_minmax_nan_propagation_templ() {
// min(nan, 0) = 0
// min(0, nan) = 0
// min(0, 0) = 0
- verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan));
- verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan));
- verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero));
- verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero));
- verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan));
- verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(kNaN));
+ verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_all_nan));
+ verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(kZero));
+ verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_zero));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNaN));
+ verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_all_nan));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero));
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero));
+
+ // Test min and max reduction
+ Tensor<Scalar, 0> val;
+ val = vec_zero.minimum();
+ VERIFY_IS_EQUAL(val(), kZero);
+ val = vec_zero.template minimum<PropagateNaN>();
+ VERIFY_IS_EQUAL(val(), kZero);
+ val = vec_zero.template minimum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), kZero);
+ val = vec_zero.maximum();
+ VERIFY_IS_EQUAL(val(), kZero);
+ val = vec_zero.template maximum<PropagateNaN>();
+ VERIFY_IS_EQUAL(val(), kZero);
+ val = vec_zero.template maximum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), kZero);
+
+ // Test NaN propagation for tensor of all NaNs.
+ val = vec_all_nan.template minimum<PropagateNaN>();
+ VERIFY((numext::isnan)(val()));
+ val = vec_all_nan.template minimum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), kInf);
+ val = vec_all_nan.template maximum<PropagateNaN>();
+ VERIFY((numext::isnan)(val()));
+ val = vec_all_nan.template maximum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), -kInf);
+
+ // Test NaN propagation for tensor with a single NaN.
+ val = vec_one_nan.template minimum<PropagateNaN>();
+ VERIFY((numext::isnan)(val()));
+ val = vec_one_nan.template minimum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), (size == 1 ? kInf : kZero));
+ val = vec_one_nan.template maximum<PropagateNaN>();
+ VERIFY((numext::isnan)(val()));
+ val = vec_one_nan.template maximum<PropagateNumbers>();
+ VERIFY_IS_EQUAL(val(), (size == 1 ? -kInf : kZero));
}
}