From 811aadbe000a01a47bc53cc228bc269fb0653be5 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Thu, 2 Jun 2016 12:41:28 -0700 Subject: Add syntactic sugar to Eigen tensors to allow more natural syntax. Specifically, this enables expressions involving: scalar + tensor scalar * tensor scalar / tensor scalar - tensor --- unsupported/test/cxx11_tensor_sugar.cpp | 34 ++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) (limited to 'unsupported/test/cxx11_tensor_sugar.cpp') diff --git a/unsupported/test/cxx11_tensor_sugar.cpp b/unsupported/test/cxx11_tensor_sugar.cpp index a03f75cfe..2f56eb495 100644 --- a/unsupported/test/cxx11_tensor_sugar.cpp +++ b/unsupported/test/cxx11_tensor_sugar.cpp @@ -33,7 +33,7 @@ static void test_comparison_sugar() { } -static void test_scalar_sugar() { +static void test_scalar_sugar_add_mul() { Tensor A(6, 7, 5); Tensor B(6, 7, 5); A.setRandom(); @@ -41,21 +41,41 @@ static void test_scalar_sugar() { const float alpha = 0.43f; const float beta = 0.21f; + const float gamma = 0.14f; - Tensor R = A * A.constant(alpha) + B * B.constant(beta); - Tensor S = A * alpha + B * beta; - - // TODO: add enough syntactic sugar to support this - // Tensor T = alpha * A + beta * B; + Tensor R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta); + Tensor S = A * alpha + B * beta + gamma; + Tensor T = gamma + alpha * A + beta * B; for (int i = 0; i < 6*7*5; ++i) { VERIFY_IS_APPROX(R(i), S(i)); + VERIFY_IS_APPROX(R(i), T(i)); } } +static void test_scalar_sugar_sub_div() { + Tensor A(6, 7, 5); + Tensor B(6, 7, 5); + A.setRandom(); + B.setRandom(); + + const float alpha = 0.43f; + const float beta = 0.21f; + const float gamma = 0.14f; + const float delta = 0.32f; + + Tensor R = A.constant(gamma) - A / A.constant(alpha) + - B.constant(beta) / B - A.constant(delta); + Tensor S = gamma - A / alpha - beta / B - delta; + + for (int i = 0; i < 6*7*5; ++i) { + VERIFY_IS_APPROX(R(i), S(i)); + } +} void test_cxx11_tensor_sugar() { CALL_SUBTEST(test_comparison_sugar()); - CALL_SUBTEST(test_scalar_sugar()); + CALL_SUBTEST(test_scalar_sugar_add_mul()); + CALL_SUBTEST(test_scalar_sugar_sub_div()); } -- cgit v1.2.3