aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-02-11 16:34:07 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-02-11 16:34:07 -0800
commitde345eff2e7e41505224e04c47e2a91b020b5a5a (patch)
treee671b5a0f9ddd6039d926c34ee9d5221cb26335a
parent17e93ba1488a03daf16ca22dd3a3f7ccae86bedf (diff)
Added a method to conjugate the content of a tensor or the result of a tensor expression.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h6
-rw-r--r--unsupported/test/cxx11_tensor_of_complex.cpp20
2 files changed, 26 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index cca716d6f..4dea1d3a0 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -171,6 +171,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
+ conjugate() const {
+ return unaryExpr(internal::scalar_conjugate_op<Scalar>());
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_pow_op<Scalar>, const Derived>
pow(Scalar exponent) const {
return unaryExpr(internal::scalar_pow_op<Scalar>(exponent));
diff --git a/unsupported/test/cxx11_tensor_of_complex.cpp b/unsupported/test/cxx11_tensor_of_complex.cpp
index 8ad04f699..25e51143e 100644
--- a/unsupported/test/cxx11_tensor_of_complex.cpp
+++ b/unsupported/test/cxx11_tensor_of_complex.cpp
@@ -48,6 +48,25 @@ static void test_abs()
}
+static void test_conjugate()
+{
+ Tensor<std::complex<float>, 1> data1(3);
+ Tensor<std::complex<double>, 1> data2(3);
+ Tensor<int, 1> data3(3);
+ data1.setRandom();
+ data2.setRandom();
+ data3.setRandom();
+
+ Tensor<std::complex<float>, 1> conj1 = data1.conjugate();
+ Tensor<std::complex<double>, 1> conj2 = data2.conjugate();
+ Tensor<int, 1> conj3 = data3.conjugate();
+ for (int i = 0; i < 3; ++i) {
+ VERIFY_IS_APPROX(conj1(i), std::conj(data1(i)));
+ VERIFY_IS_APPROX(conj2(i), std::conj(data2(i)));
+ VERIFY_IS_APPROX(conj3(i), data3(i));
+ }
+}
+
static void test_contractions()
{
Tensor<std::complex<float>, 4> t_left(30, 50, 8, 31);
@@ -77,5 +96,6 @@ void test_cxx11_tensor_of_complex()
{
CALL_SUBTEST(test_additions());
CALL_SUBTEST(test_abs());
+ CALL_SUBTEST(test_conjugate());
CALL_SUBTEST(test_contractions());
}