diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 21:48:31 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 21:48:31 +0000 |
commit | c6953f799b01d36f4236b64f351cc1446e0abe17 (patch) | |
tree | 9abcded97c6effc010d08787c5b43ef7bb043b54 /unsupported/test | |
parent | 807e51528d220c0efed870f0505dea81a5776085 (diff) |
Add packet generic ops `predux_fmin`, `predux_fmin_nan`, `predux_fmax`, and `predux_fmax_nan` that implement reductions with `PropagateNaN`, and `PropagateNumbers` semantics. Add (slow) generic implementations for most reductions.
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_expr.cpp | 94 |
1 files changed, 67 insertions, 27 deletions
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index 7fac3b4ed..556d01d4d 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -302,12 +302,17 @@ static void test_select() template <typename Scalar> void test_minmax_nan_propagation_templ() { for (int size = 1; size < 17; ++size) { - const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN(); + std::cout << "size = " << size << std::endl; + const Scalar kNaN = std::numeric_limits<Scalar>::quiet_NaN(); + const Scalar kInf = std::numeric_limits<Scalar>::infinity(); const Scalar kZero(0); - Tensor<Scalar, 1> vec_nan(size); + Tensor<Scalar, 1> vec_all_nan(size); + Tensor<Scalar, 1> vec_one_nan(size); Tensor<Scalar, 1> vec_zero(size); - vec_nan.setConstant(kNan); + vec_all_nan.setConstant(kNaN); vec_zero.setZero(); + vec_one_nan.setZero(); + vec_one_nan(size/2) = kNaN; auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) { for (int i = 0; i < size; ++i) { @@ -326,12 +331,12 @@ void test_minmax_nan_propagation_templ() { // max(nan, 0) = nan // max(0, nan) = nan // max(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan)); - verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan)); - verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero)); - verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero)); - verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan)); - verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(kZero)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNaN)); + verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero)); verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero)); @@ -340,12 +345,12 @@ void test_minmax_nan_propagation_templ() { // max(nan, 0) = 0 // max(0, nan) = 0 // max(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan)); - verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan)); - verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero)); - verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero)); - verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan)); - verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(kZero)); + verify_all_zero(vec_all_nan.template cwiseMax<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNaN)); + verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero)); verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero)); @@ -354,12 +359,12 @@ void test_minmax_nan_propagation_templ() { // min(nan, 0) = nan // min(0, nan) = nan // min(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan)); - verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan)); - verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero)); - verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero)); - verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan)); - verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_all_nan)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(kZero)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNaN>(vec_zero)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNaN)); + verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero)); verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero)); @@ -368,14 +373,49 @@ void test_minmax_nan_propagation_templ() { // min(nan, 0) = 0 // min(0, nan) = 0 // min(0, 0) = 0 - verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan)); - verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan)); - verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero)); - verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero)); - verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan)); - verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(kNaN)); + verify_all_nan(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_all_nan)); + verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(kZero)); + verify_all_zero(vec_all_nan.template cwiseMin<PropagateNumbers>(vec_zero)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNaN)); + verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_all_nan)); verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero)); verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero)); + + // Test min and max reduction + Tensor<Scalar, 0> val; + val = vec_zero.minimum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum<PropagateNaN>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.maximum(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum<PropagateNaN>(); + VERIFY_IS_EQUAL(val(), kZero); + val = vec_zero.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kZero); + + // Test NaN propagation for tensor of all NaNs. + val = vec_all_nan.template minimum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), kInf); + val = vec_all_nan.template maximum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_all_nan.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), -kInf); + + // Test NaN propagation for tensor with a single NaN. + val = vec_one_nan.template minimum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template minimum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), (size == 1 ? kInf : kZero)); + val = vec_one_nan.template maximum<PropagateNaN>(); + VERIFY((numext::isnan)(val())); + val = vec_one_nan.template maximum<PropagateNumbers>(); + VERIFY_IS_EQUAL(val(), (size == 1 ? -kInf : kZero)); } } |