aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_expr.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-05 10:49:34 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-05 10:49:34 -0700
commit8998f4099e20ebc80db0aba2582301cd48d31c5a (patch)
tree18fb2111a71d612cf9e31de79e1c4b7250501fdf /unsupported/test/cxx11_tensor_expr.cpp
parent6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 (diff)
Created additional tests for the tensor code.
Diffstat (limited to 'unsupported/test/cxx11_tensor_expr.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_expr.cpp149
1 files changed, 134 insertions, 15 deletions
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp
index e0124da8c..e85fcbfa9 100644
--- a/unsupported/test/cxx11_tensor_expr.cpp
+++ b/unsupported/test/cxx11_tensor_expr.cpp
@@ -28,10 +28,10 @@ static void test_1d()
float data3[6];
TensorMap<Tensor<float, 1>> vec3(data3, 6);
- vec3 = vec1.cwiseSqrt();
+ vec3 = vec1.sqrt();
float data4[6];
TensorMap<Tensor<float, 1, RowMajor>> vec4(data4, 6);
- vec4 = vec2.cwiseSqrt();
+ vec4 = vec2.square();
VERIFY_IS_APPROX(vec3(0), sqrtf(4.0));
VERIFY_IS_APPROX(vec3(1), sqrtf(8.0));
@@ -40,12 +40,12 @@ static void test_1d()
VERIFY_IS_APPROX(vec3(4), sqrtf(23.0));
VERIFY_IS_APPROX(vec3(5), sqrtf(42.0));
- VERIFY_IS_APPROX(vec4(0), sqrtf(0.0));
- VERIFY_IS_APPROX(vec4(1), sqrtf(1.0));
- VERIFY_IS_APPROX(vec4(2), sqrtf(2.0));
- VERIFY_IS_APPROX(vec4(3), sqrtf(3.0));
- VERIFY_IS_APPROX(vec4(4), sqrtf(4.0));
- VERIFY_IS_APPROX(vec4(5), sqrtf(5.0));
+ VERIFY_IS_APPROX(vec4(0), 0.0f);
+ VERIFY_IS_APPROX(vec4(1), 1.0f);
+ VERIFY_IS_APPROX(vec4(2), 2.0f * 2.0f);
+ VERIFY_IS_APPROX(vec4(3), 3.0f * 3.0f);
+ VERIFY_IS_APPROX(vec4(4), 4.0f * 4.0f);
+ VERIFY_IS_APPROX(vec4(5), 5.0f * 5.0f);
vec3 = vec1 + vec2;
VERIFY_IS_APPROX(vec3(0), 4.0f + 0.0f);
@@ -79,8 +79,8 @@ static void test_2d()
Tensor<float, 2> mat3(2,3);
Tensor<float, 2, RowMajor> mat4(2,3);
- mat3 = mat1.cwiseAbs();
- mat4 = mat2.cwiseAbs();
+ mat3 = mat1.abs();
+ mat4 = mat2.abs();
VERIFY_IS_APPROX(mat3(0,0), 0.0f);
VERIFY_IS_APPROX(mat3(0,1), 1.0f);
@@ -102,7 +102,7 @@ static void test_3d()
Tensor<float, 3> mat1(2,3,7);
Tensor<float, 3, RowMajor> mat2(2,3,7);
- float val = 0.0;
+ float val = 1.0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
@@ -118,28 +118,147 @@ static void test_3d()
Tensor<float, 3, RowMajor> mat4(2,3,7);
mat4 = mat2 * 3.14f;
Tensor<float, 3> mat5(2,3,7);
- mat5 = mat1.cwiseSqrt().cwiseSqrt();
+ mat5 = mat1.inverse().log();
Tensor<float, 3, RowMajor> mat6(2,3,7);
- mat6 = mat2.cwiseSqrt() * 3.14f;
+ mat6 = mat2.pow(0.5f) * 3.14f;
+ Tensor<float, 3> mat7(2,3,7);
+ mat7 = mat1.cwiseMax(mat5 * 2.0f).exp();
+ Tensor<float, 3, RowMajor> mat8(2,3,7);
+ mat8 = (-mat2).exp() * 3.14f;
- val = 0.0;
+ val = 1.0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_APPROX(mat3(i,j,k), val + val);
VERIFY_IS_APPROX(mat4(i,j,k), val * 3.14f);
- VERIFY_IS_APPROX(mat5(i,j,k), sqrtf(sqrtf(val)));
+ VERIFY_IS_APPROX(mat5(i,j,k), logf(1.0f/val));
VERIFY_IS_APPROX(mat6(i,j,k), sqrtf(val) * 3.14f);
+ VERIFY_IS_APPROX(mat7(i,j,k), expf((std::max)(val, mat5(i,j,k) * 2.0f)));
+ VERIFY_IS_APPROX(mat8(i,j,k), expf(-val) * 3.14f);
val += 1.0;
}
}
}
}
+static void test_constants()
+{
+ Tensor<float, 3> mat1(2,3,7);
+ Tensor<float, 3> mat2(2,3,7);
+ Tensor<float, 3> mat3(2,3,7);
+
+ float val = 1.0;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ mat1(i,j,k) = val;
+ val += 1.0;
+ }
+ }
+ }
+ mat2 = mat1.constant(3.14f);
+ mat3 = mat1.cwiseMax(7.3f).exp();
+
+ val = 1.0;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(mat2(i,j,k), 3.14f);
+ VERIFY_IS_APPROX(mat3(i,j,k), expf((std::max)(val, 7.3f)));
+ val += 1.0;
+ }
+ }
+ }
+}
+
+
+static void test_functors()
+{
+ Tensor<float, 3> mat1(2,3,7);
+ Tensor<float, 3> mat2(2,3,7);
+ Tensor<float, 3> mat3(2,3,7);
+
+ float val = 1.0;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ mat1(i,j,k) = val;
+ val += 1.0;
+ }
+ }
+ }
+ mat2 = mat1.inverse().unaryExpr(&asinf);
+ mat3 = mat1.unaryExpr(&tanhf);
+
+ val = 1.0;
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(mat2(i,j,k), asinf(1.0f / mat1(i,j,k)));
+ VERIFY_IS_APPROX(mat3(i,j,k), tanhf(mat1(i,j,k)));
+ val += 1.0;
+ }
+ }
+ }
+}
+
+static void test_type_casting()
+{
+ Tensor<bool, 3> mat1(2,3,7);
+ Tensor<float, 3> mat2(2,3,7);
+ Tensor<double, 3> mat3(2,3,7);
+ mat1.setRandom();
+ mat2.setRandom();
+
+ mat3 = mat1.template cast<double>();
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(mat3(i,j,k), mat1(i,j,k) ? 1.0 : 0.0);
+ }
+ }
+ }
+
+ mat3 = mat2.template cast<double>();
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(mat3(i,j,k), static_cast<double>(mat2(i,j,k)));
+ }
+ }
+ }
+}
+
+static void test_select()
+{
+ Tensor<float, 3> selector(2,3,7);
+ Tensor<float, 3> mat1(2,3,7);
+ Tensor<float, 3> mat2(2,3,7);
+ Tensor<float, 3> result(2,3,7);
+
+ selector.setRandom();
+ mat1.setRandom();
+ mat2.setRandom();
+ result = (selector > selector.constant(0.5f)).select(mat1, mat2);
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 7; ++k) {
+ VERIFY_IS_APPROX(result(i,j,k), (selector(i,j,k) > 0.5f) ? mat1(i,j,k) : mat2(i,j,k));
+ }
+ }
+ }
+}
+
void test_cxx11_tensor_expr()
{
CALL_SUBTEST(test_1d());
CALL_SUBTEST(test_2d());
CALL_SUBTEST(test_3d());
+ CALL_SUBTEST(test_constants());
+ CALL_SUBTEST(test_functors());
+ CALL_SUBTEST(test_type_casting());
+ CALL_SUBTEST(test_select());
}