diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_contraction.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 0e16308a2..ace97057f 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -87,19 +87,14 @@ static void test_scalar() vec1.setRandom(); vec2.setRandom(); - Tensor<float, 1, DataLayout> scalar(1); - scalar.setZero(); Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}}; - typedef TensorEvaluator<decltype(vec1.contract(vec2, dims)), DefaultDevice> Evaluator; - Evaluator eval(vec1.contract(vec2, dims), DefaultDevice()); - eval.evalTo(scalar.data()); - EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE); + Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims); float expected = 0.0f; for (int i = 0; i < 6; ++i) { expected += vec1(i) * vec2(i); } - VERIFY_IS_APPROX(scalar(0), expected); + VERIFY_IS_APPROX(scalar(), expected); } template<int DataLayout> @@ -494,6 +489,27 @@ static void test_tensor_product() } +template<int DataLayout> +static void test_const_inputs() +{ + Tensor<float, 2, DataLayout> in1(2, 3); + Tensor<float, 2, DataLayout> in2(3, 2); + in1.setRandom(); + in2.setRandom(); + + TensorMap<Tensor<const float, 2, DataLayout> > mat1(in1.data(), 2, 3); + TensorMap<Tensor<const float, 2, DataLayout> > mat2(in2.data(), 3, 2); + Tensor<float, 2, DataLayout> mat3(2,2); + + Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}}; + mat3 = mat1.contract(mat2, dims); + + VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0)); + VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1)); + VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0)); + VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1)); +} + void test_cxx11_tensor_contraction() { CALL_SUBTEST(test_evals<ColMajor>()); @@ -524,4 +540,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_small_blocking_factors<RowMajor>()); CALL_SUBTEST(test_tensor_product<ColMajor>()); CALL_SUBTEST(test_tensor_product<RowMajor>()); + CALL_SUBTEST(test_const_inputs<ColMajor>()); + CALL_SUBTEST(test_const_inputs<RowMajor>()); } |