aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_contraction.cpp
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-09-04 20:27:28 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-09-04 20:27:28 -0700
commit1abe4ed14c0012d85e833c5f507f282cf26edc36 (patch)
treec9d8e8fc6f6fdcba6d3101a2e3baf5634ebffd8c /unsupported/test/cxx11_tensor_contraction.cpp
parentd43f737b4ad52e84a3b4d954d9bfb4c40cf9e819 (diff)
Created more regression tests
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_contraction.cpp166
1 files changed, 166 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp
index fc67d500b..a37fcd967 100644
--- a/unsupported/test/cxx11_tensor_contraction.cpp
+++ b/unsupported/test/cxx11_tensor_contraction.cpp
@@ -141,6 +141,66 @@ static void test_multidims()
}
+static void test_holes() {
+ Tensor<float, 4> t1(2, 5, 7, 3);
+ Tensor<float, 5> t2(2, 7, 11, 13, 3);
+ t1.setRandom();
+ t2.setRandom();
+
+ Eigen::array<DimPair, 2> dims({{DimPair(0, 0), DimPair(3, 4)}});
+ Tensor<float, 5> result = t1.contract(t2, dims);
+ VERIFY_IS_EQUAL(result.dimension(0), 5);
+ VERIFY_IS_EQUAL(result.dimension(1), 7);
+ VERIFY_IS_EQUAL(result.dimension(2), 7);
+ VERIFY_IS_EQUAL(result.dimension(3), 11);
+ VERIFY_IS_EQUAL(result.dimension(4), 13);
+
+ for (int i = 0; i < 5; ++i) {
+ for (int j = 0; j < 5; ++j) {
+ for (int k = 0; k < 5; ++k) {
+ for (int l = 0; l < 5; ++l) {
+ for (int m = 0; m < 5; ++m) {
+ VERIFY_IS_APPROX(result(i, j, k, l, m),
+ t1(0, i, j, 0) * t2(0, k, l, m, 0) +
+ t1(1, i, j, 0) * t2(1, k, l, m, 0) +
+ t1(0, i, j, 1) * t2(0, k, l, m, 1) +
+ t1(1, i, j, 1) * t2(1, k, l, m, 1) +
+ t1(0, i, j, 2) * t2(0, k, l, m, 2) +
+ t1(1, i, j, 2) * t2(1, k, l, m, 2));
+ }
+ }
+ }
+ }
+ }
+}
+
+
+static void test_full_redux()
+{
+ Tensor<float, 2> t1(2, 2);
+ Tensor<float, 3> t2(2, 2, 2);
+ t1.setRandom();
+ t2.setRandom();
+
+ Eigen::array<DimPair, 2> dims({{DimPair(0, 0), DimPair(1, 1)}});
+ Tensor<float, 1> result = t1.contract(t2, dims);
+ VERIFY_IS_EQUAL(result.dimension(0), 2);
+ VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(1, 0, 0)
+ + t1(0, 1) * t2(0, 1, 0) + t1(1, 1) * t2(1, 1, 0));
+ VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(0, 0, 1) + t1(1, 0) * t2(1, 0, 1)
+ + t1(0, 1) * t2(0, 1, 1) + t1(1, 1) * t2(1, 1, 1));
+
+ dims[0] = DimPair(1, 0);
+ dims[1] = DimPair(2, 1);
+ result = t2.contract(t1, dims);
+ VERIFY_IS_EQUAL(result.dimension(0), 2);
+ VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(0, 1, 0)
+ + t1(0, 1) * t2(0, 0, 1) + t1(1, 1) * t2(0, 1, 1));
+ VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(1, 0, 0) + t1(1, 0) * t2(1, 1, 0)
+ + t1(0, 1) * t2(1, 0, 1) + t1(1, 1) * t2(1, 1, 1));
+}
+
+
static void test_expr()
{
Tensor<float, 2> mat1(2, 3);
@@ -160,10 +220,116 @@ static void test_expr()
}
+static void test_out_of_order_contraction()
+{
+ Tensor<float, 3> mat1(2, 2, 2);
+ Tensor<float, 3> mat2(2, 2, 2);
+
+ mat1.setRandom();
+ mat2.setRandom();
+
+ Tensor<float, 2> mat3(2, 2);
+
+ Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(0, 2)}});
+ mat3 = mat1.contract(mat2, dims);
+
+ VERIFY_IS_APPROX(mat3(0, 0),
+ mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
+ mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
+ VERIFY_IS_APPROX(mat3(1, 0),
+ mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
+ mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
+ VERIFY_IS_APPROX(mat3(0, 1),
+ mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
+ mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
+ VERIFY_IS_APPROX(mat3(1, 1),
+ mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
+ mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
+
+ Eigen::array<DimPair, 2> dims2({{DimPair(0, 2), DimPair(2, 0)}});
+ mat3 = mat1.contract(mat2, dims2);
+
+ VERIFY_IS_APPROX(mat3(0, 0),
+ mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
+ mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
+ VERIFY_IS_APPROX(mat3(1, 0),
+ mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
+ mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
+ VERIFY_IS_APPROX(mat3(0, 1),
+ mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
+ mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
+ VERIFY_IS_APPROX(mat3(1, 1),
+ mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
+ mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
+
+}
+
+
+static void test_consistency()
+{
+ // this does something like testing (A*B)^T = (B^T * A^T)
+
+ Tensor<float, 3> mat1(4, 3, 5);
+ Tensor<float, 5> mat2(3, 2, 1, 5, 4);
+ mat1.setRandom();
+ mat2.setRandom();
+
+ Tensor<float, 4> mat3(5, 2, 1, 5);
+ Tensor<float, 4> mat4(2, 1, 5, 5);
+
+ // contract on dimensions of size 4 and 3
+ Eigen::array<DimPair, 2> dims1({{DimPair(0, 4), DimPair(1, 0)}});
+ Eigen::array<DimPair, 2> dims2({{DimPair(4, 0), DimPair(0, 1)}});
+
+ mat3 = mat1.contract(mat2, dims1);
+ mat4 = mat2.contract(mat1, dims2);
+
+ // check that these are equal except for ordering of dimensions
+ for (size_t i = 0; i < 5; i++) {
+ for (size_t j = 0; j < 10; j++) {
+ VERIFY_IS_APPROX(mat3.data()[i + 5 * j], mat4.data()[j + 10 * i]);
+ }
+ }
+}
+
+
+static void test_large_contraction()
+{
+ Tensor<float, 4> t_left(30, 50, 8, 31);
+ Tensor<float, 5> t_right(8, 31, 7, 20, 10);
+ Tensor<float, 5> t_result(30, 50, 7, 20, 10);
+
+ t_left.setRandom();
+ t_right.setRandom();
+
+ typedef Map<MatrixXf> MapXf;
+ MapXf m_left(t_left.data(), 1500, 248);
+ MapXf m_right(t_right.data(), 248, 1400);
+ MatrixXf m_result(1500, 1400);
+
+ // this contraction should be equivalent to a single matrix multiplication
+ Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
+
+ // compute results by separate methods
+ t_result = t_left.contract(t_right, dims);
+ m_result = m_left * m_right;
+
+ for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
+ VERIFY(&t_result.data()[i] != &m_result.data()[i]);
+ VERIFY_IS_APPROX(t_result.data()[i], m_result.data()[i]);
+ }
+}
+
+
void test_cxx11_tensor_contraction()
{
CALL_SUBTEST(test_evals());
CALL_SUBTEST(test_scalar());
CALL_SUBTEST(test_multidims());
+ CALL_SUBTEST(test_holes());
+ CALL_SUBTEST(test_full_redux());
CALL_SUBTEST(test_expr());
+ CALL_SUBTEST(test_out_of_order_contraction());
+ CALL_SUBTEST(test_consistency());
+ CALL_SUBTEST(test_large_contraction());
}