about summary refs log tree commit diff homepage
path: root/test/gpu_basic.cu
diff options
context:
space:
mode:
author    Antonio Sanchez <cantonios@google.com>  2021-01-06 09:41:15 -0800
committer Antonio Sánchez <cantonios@google.com>  2021-01-22 18:19:19 +0000
commitf19bcffee6b8018ca101ceb370e6e550a940289f (patch)
tree36447572f9f35914470c66811e613c20bc4e044e /test/gpu_basic.cu
parent65e2169c4521660d30f4d90df61da5f3dd9f45bd (diff)
Specialize std::complex operators for use on GPU device.
NVCC and older versions of clang do not fully support `std::complex` on device, leading to either compile errors (Cannot call `__host__` function) or worse, runtime errors (Illegal instruction). For most functions, we can implement specialized `numext` versions. Here we specialize the standard operators (with the exception of stream operators and member function operators with a scalar that are already specialized in `<complex>`) so they can be used in device code as well. To import these operators into the current scope, use `EIGEN_USING_STD_COMPLEX_OPERATORS`. By default, these are imported into the `Eigen`, `Eigen:internal`, and `Eigen::numext` namespaces. This allow us to remove specializations of the sum/difference/product/quotient ops, and allow us to treat complex numbers like most other scalars (e.g. in tests).
Diffstat (limited to 'test/gpu_basic.cu')
-rw-r--r--  test/gpu_basic.cu  112
1 file changed, 112 insertions, 0 deletions
diff --git a/test/gpu_basic.cu b/test/gpu_basic.cu
index b82b94d9b..46e4a436f 100644
--- a/test/gpu_basic.cu
+++ b/test/gpu_basic.cu
@@ -107,6 +107,116 @@ struct complex_sqrt {
};
template<typename T>
+struct complex_operators {
+ EIGEN_DEVICE_FUNC
+ void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
+ {
+ using namespace Eigen;
+ typedef typename T::Scalar ComplexType;
+ typedef typename T::Scalar::value_type ValueType;
+ const int num_scalar_operators = 24;
+ const int num_vector_operators = 23; // no unary + operator.
+ int out_idx = i * (num_scalar_operators + num_vector_operators * T::MaxSizeAtCompileTime);
+
+ // Scalar operators.
+ const ComplexType a = in[i];
+ const ComplexType b = in[i + 1];
+
+ out[out_idx++] = +a;
+ out[out_idx++] = -a;
+
+ out[out_idx++] = a + b;
+ out[out_idx++] = a + numext::real(b);
+ out[out_idx++] = numext::real(a) + b;
+ out[out_idx++] = a - b;
+ out[out_idx++] = a - numext::real(b);
+ out[out_idx++] = numext::real(a) - b;
+ out[out_idx++] = a * b;
+ out[out_idx++] = a * numext::real(b);
+ out[out_idx++] = numext::real(a) * b;
+ out[out_idx++] = a / b;
+ out[out_idx++] = a / numext::real(b);
+ out[out_idx++] = numext::real(a) / b;
+
+ out[out_idx] = a; out[out_idx++] += b;
+ out[out_idx] = a; out[out_idx++] -= b;
+ out[out_idx] = a; out[out_idx++] *= b;
+ out[out_idx] = a; out[out_idx++] /= b;
+
+ const ComplexType true_value = ComplexType(ValueType(1), ValueType(0));
+ const ComplexType false_value = ComplexType(ValueType(0), ValueType(0));
+ out[out_idx++] = (a == b ? true_value : false_value);
+ out[out_idx++] = (a == numext::real(b) ? true_value : false_value);
+ out[out_idx++] = (numext::real(a) == b ? true_value : false_value);
+ out[out_idx++] = (a != b ? true_value : false_value);
+ out[out_idx++] = (a != numext::real(b) ? true_value : false_value);
+ out[out_idx++] = (numext::real(a) != b ? true_value : false_value);
+
+ // Vector versions.
+ T x1(in + i);
+ T x2(in + i + 1);
+ const int res_size = T::MaxSizeAtCompileTime * num_scalar_operators;
+ const int size = T::MaxSizeAtCompileTime;
+ int block_idx = 0;
+
+ Map<VectorX<ComplexType>> res(out + out_idx, res_size);
+ res.segment(block_idx, size) = -x1;
+ block_idx += size;
+
+ res.segment(block_idx, size) = x1 + x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1 + x2.real();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.real() + x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1 - x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1 - x2.real();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.real() - x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1.array() * x2.array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.array() * x2.real().array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.real().array() * x2.array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.array() / x2.array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.array() / x2.real().array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1.real().array() / x2.array();
+ block_idx += size;
+
+ res.segment(block_idx, size) = x1; res.segment(block_idx, size) += x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1; res.segment(block_idx, size) -= x2;
+ block_idx += size;
+ res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() *= x2.array();
+ block_idx += size;
+ res.segment(block_idx, size) = x1; res.segment(block_idx, size).array() /= x2.array();
+ block_idx += size;
+
+ // Equality comparisons currently not functional on device
+ // (std::equal_to<T> is host-only).
+ // const T true_vector = T::Constant(true_value);
+ // const T false_vector = T::Constant(false_value);
+ // res.segment(block_idx, size) = (x1 == x2 ? true_vector : false_vector);
+ // block_idx += size;
+ // res.segment(block_idx, size) = (x1 == x2.real() ? true_vector : false_vector);
+ // block_idx += size;
+ // res.segment(block_idx, size) = (x1.real() == x2 ? true_vector : false_vector);
+ // block_idx += size;
+ // res.segment(block_idx, size) = (x1 != x2 ? true_vector : false_vector);
+ // block_idx += size;
+ // res.segment(block_idx, size) = (x1 != x2.real() ? true_vector : false_vector);
+ // block_idx += size;
+ // res.segment(block_idx, size) = (x1.real() != x2 ? true_vector : false_vector);
+ // block_idx += size;
+ }
+};
+
+template<typename T>
struct replicate {
EIGEN_DEVICE_FUNC
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
@@ -297,6 +407,8 @@ EIGEN_DECLARE_TEST(gpu_basic)
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix3f>(), nthreads, in, out) );
CALL_SUBTEST( run_and_compare_to_gpu(eigenvalues_direct<Matrix2f>(), nthreads, in, out) );
+ // Test std::complex.
+ CALL_SUBTEST( run_and_compare_to_gpu(complex_operators<Vector3cf>(), nthreads, cfin, cfout) );
CALL_SUBTEST( test_with_infs_nans(complex_sqrt<Vector3cf>(), nthreads, cfin, cfout) );
#if defined(__NVCC__)