diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-02 12:41:28 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2016-06-02 12:41:28 -0700 |
commit | 811aadbe000a01a47bc53cc228bc269fb0653be5 (patch) | |
tree | b27f35661bf5dece6a9dc4673e6db6a08ebb8772 | |
parent | 6021c90fdf034cd94502cd1e122407a88ffe6105 (diff) |
Add syntactic sugar to Eigen tensors to allow more natural syntax.
Specifically, this enables expressions involving:
scalar + tensor
scalar * tensor
scalar / tensor
scalar - tensor
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 30 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_sugar.cpp | 34 |
2 files changed, 57 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 1eaa8d4fc..12f8a1499 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -216,6 +216,13 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE friend + const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> + operator+ (Scalar lhs, const Derived& rhs) { + return rhs + lhs; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived> operator- (Scalar rhs) const { EIGEN_STATIC_ASSERT((NumTraits<Scalar>::IsSigned || internal::is_same<Scalar, const std::complex<float> >::value), YOU_MADE_A_PROGRAMMING_MISTAKE); @@ -223,18 +230,41 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE friend + const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>, + const TensorCwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> > + operator- (Scalar lhs, const Derived& rhs) { + return -rhs + lhs; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived> operator* (Scalar rhs) const { return unaryExpr(internal::scalar_multiple_op<Scalar>(rhs)); } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE friend + const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived> + operator* (Scalar lhs, const Derived& rhs) { + return rhs * lhs; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_quotient1_op<Scalar>, const Derived> operator/ (Scalar rhs) const { return unaryExpr(internal::scalar_quotient1_op<Scalar>(rhs)); } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE friend + const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, + const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> > + operator/ (Scalar lhs, const Derived& rhs) { + return rhs.inverse() * lhs; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_mod_op<Scalar>, const Derived> operator% (Scalar rhs) const { EIGEN_STATIC_ASSERT(NumTraits<Scalar>::IsInteger, YOU_MADE_A_PROGRAMMING_MISTAKE_TRY_MOD); diff --git a/unsupported/test/cxx11_tensor_sugar.cpp b/unsupported/test/cxx11_tensor_sugar.cpp index a03f75cfe..2f56eb495 100644 --- a/unsupported/test/cxx11_tensor_sugar.cpp +++ b/unsupported/test/cxx11_tensor_sugar.cpp @@ -33,7 +33,7 @@ static void test_comparison_sugar() { } -static void test_scalar_sugar() { +static void test_scalar_sugar_add_mul() { Tensor<float, 3> A(6, 7, 5); Tensor<float, 3> B(6, 7, 5); A.setRandom(); @@ -41,21 +41,41 @@ static void test_scalar_sugar() { const float alpha = 0.43f; const float beta = 0.21f; + const float gamma = 0.14f; - Tensor<float, 3> R = A * A.constant(alpha) + B * B.constant(beta); - Tensor<float, 3> S = A * alpha + B * beta; - - // TODO: add enough syntactic sugar to support this - // Tensor<float, 3> T = alpha * A + beta * B; + Tensor<float, 3> R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta); + Tensor<float, 3> S = A * alpha + B * beta + gamma; + Tensor<float, 3> T = gamma + alpha * A + beta * B; for (int i = 0; i < 6*7*5; ++i) { VERIFY_IS_APPROX(R(i), S(i)); + VERIFY_IS_APPROX(R(i), T(i)); } } +static void test_scalar_sugar_sub_div() { + Tensor<float, 3> A(6, 7, 5); + Tensor<float, 3> B(6, 7, 5); + A.setRandom(); + B.setRandom(); + + const float alpha = 0.43f; + const float beta = 0.21f; + const float gamma = 0.14f; + const float delta = 0.32f; + + Tensor<float, 3> R = A.constant(gamma) - A / A.constant(alpha) + - B.constant(beta) / B - A.constant(delta); + Tensor<float, 3> S = gamma - A / alpha - beta / B - delta; + + for (int i = 0; i < 6*7*5; ++i) { + VERIFY_IS_APPROX(R(i), S(i)); + } +} void test_cxx11_tensor_sugar() { CALL_SUBTEST(test_comparison_sugar()); - CALL_SUBTEST(test_scalar_sugar()); + CALL_SUBTEST(test_scalar_sugar_add_mul()); + CALL_SUBTEST(test_scalar_sugar_sub_div()); } |