aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2016-06-02 12:41:28 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2016-06-02 12:41:28 -0700
commit811aadbe000a01a47bc53cc228bc269fb0653be5 (patch)
treeb27f35661bf5dece6a9dc4673e6db6a08ebb8772
parent6021c90fdf034cd94502cd1e122407a88ffe6105 (diff)
Add syntactic sugar to Eigen tensors to allow more natural syntax.
Specifically, this enables expressions involving: scalar + tensor scalar * tensor scalar / tensor scalar - tensor
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h30
-rw-r--r--unsupported/test/cxx11_tensor_sugar.cpp34
2 files changed, 57 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 1eaa8d4fc..12f8a1499 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -216,6 +216,13 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE friend
+ const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>
+ operator+ (Scalar lhs, const Derived& rhs) {
+ return rhs + lhs;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived>
operator- (Scalar rhs) const {
EIGEN_STATIC_ASSERT((NumTraits<Scalar>::IsSigned || internal::is_same<Scalar, const std::complex<float> >::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -223,18 +230,41 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE friend
+ const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>,
+ const TensorCwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> >
+ operator- (Scalar lhs, const Derived& rhs) {
+ return -rhs + lhs;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived>
operator* (Scalar rhs) const {
return unaryExpr(internal::scalar_multiple_op<Scalar>(rhs));
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE friend
+ const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived>
+ operator* (Scalar lhs, const Derived& rhs) {
+ return rhs * lhs;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_quotient1_op<Scalar>, const Derived>
operator/ (Scalar rhs) const {
return unaryExpr(internal::scalar_quotient1_op<Scalar>(rhs));
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE friend
+ const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>,
+ const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> >
+ operator/ (Scalar lhs, const Derived& rhs) {
+ return rhs.inverse() * lhs;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_mod_op<Scalar>, const Derived>
operator% (Scalar rhs) const {
EIGEN_STATIC_ASSERT(NumTraits<Scalar>::IsInteger, YOU_MADE_A_PROGRAMMING_MISTAKE_TRY_MOD);
diff --git a/unsupported/test/cxx11_tensor_sugar.cpp b/unsupported/test/cxx11_tensor_sugar.cpp
index a03f75cfe..2f56eb495 100644
--- a/unsupported/test/cxx11_tensor_sugar.cpp
+++ b/unsupported/test/cxx11_tensor_sugar.cpp
@@ -33,7 +33,7 @@ static void test_comparison_sugar() {
}
-static void test_scalar_sugar() {
+static void test_scalar_sugar_add_mul() {
Tensor<float, 3> A(6, 7, 5);
Tensor<float, 3> B(6, 7, 5);
A.setRandom();
@@ -41,21 +41,41 @@ static void test_scalar_sugar() {
const float alpha = 0.43f;
const float beta = 0.21f;
+ const float gamma = 0.14f;
- Tensor<float, 3> R = A * A.constant(alpha) + B * B.constant(beta);
- Tensor<float, 3> S = A * alpha + B * beta;
-
- // TODO: add enough syntactic sugar to support this
- // Tensor<float, 3> T = alpha * A + beta * B;
+ Tensor<float, 3> R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta);
+ Tensor<float, 3> S = A * alpha + B * beta + gamma;
+ Tensor<float, 3> T = gamma + alpha * A + beta * B;
for (int i = 0; i < 6*7*5; ++i) {
VERIFY_IS_APPROX(R(i), S(i));
+ VERIFY_IS_APPROX(R(i), T(i));
}
}
+static void test_scalar_sugar_sub_div() {
+ Tensor<float, 3> A(6, 7, 5);
+ Tensor<float, 3> B(6, 7, 5);
+ A.setRandom();
+ B.setRandom();
+
+ const float alpha = 0.43f;
+ const float beta = 0.21f;
+ const float gamma = 0.14f;
+ const float delta = 0.32f;
+
+ Tensor<float, 3> R = A.constant(gamma) - A / A.constant(alpha)
+ - B.constant(beta) / B - A.constant(delta);
+ Tensor<float, 3> S = gamma - A / alpha - beta / B - delta;
+
+ for (int i = 0; i < 6*7*5; ++i) {
+ VERIFY_IS_APPROX(R(i), S(i));
+ }
+}
void test_cxx11_tensor_sugar()
{
CALL_SUBTEST(test_comparison_sugar());
- CALL_SUBTEST(test_scalar_sugar());
+ CALL_SUBTEST(test_scalar_sugar_add_mul());
+ CALL_SUBTEST(test_scalar_sugar_sub_div());
}