aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_sugar.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/test/cxx11_tensor_sugar.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_sugar.cpp25
1 files changed, 24 insertions, 1 deletions
diff --git a/unsupported/test/cxx11_tensor_sugar.cpp b/unsupported/test/cxx11_tensor_sugar.cpp
index 98671a986..a03f75cfe 100644
--- a/unsupported/test/cxx11_tensor_sugar.cpp
+++ b/unsupported/test/cxx11_tensor_sugar.cpp
@@ -18,7 +18,7 @@ static void test_comparison_sugar() {
#define TEST_TENSOR_EQUAL(e1, e2) \
b = ((e1) == (e2)).all(); \
- VERIFY(b(0))
+ VERIFY(b())
#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))
@@ -32,7 +32,30 @@ static void test_comparison_sugar() {
#undef TEST_TENSOR_EQUAL
}
+
+static void test_scalar_sugar() {
+ Tensor<float, 3> A(6, 7, 5);
+ Tensor<float, 3> B(6, 7, 5);
+ A.setRandom();
+ B.setRandom();
+
+ const float alpha = 0.43f;
+ const float beta = 0.21f;
+
+ Tensor<float, 3> R = A * A.constant(alpha) + B * B.constant(beta);
+ Tensor<float, 3> S = A * alpha + B * beta;
+
+ // TODO: add enough syntactic sugar to support this
+ // Tensor<float, 3> T = alpha * A + beta * B;
+
+ for (int i = 0; i < 6*7*5; ++i) {
+ VERIFY_IS_APPROX(R(i), S(i));
+ }
+}
+
+
void test_cxx11_tensor_sugar()
{
CALL_SUBTEST(test_comparison_sugar());
+ CALL_SUBTEST(test_scalar_sugar());
}