diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-06 09:29:13 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-06 09:29:13 -0800 |
commit | 9f98650d0a82d4757afb4503ce6f2b6f61763463 (patch) | |
tree | 7c8f8d64461d5dc97a2e7b975dc56d8f66cc28cb | |
parent | 509e4ddc02e0d70b8c1ee325f3b18624d4235c1e (diff) |
Ensured that contractions that can be reduced to a matrix vector product work correctly even when the input coefficients aren't aligned.
-rw-r--r-- | Eigen/src/Core/products/GeneralMatrixVector.h | 8 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_contraction.cpp | 48 |
2 files changed, 54 insertions, 2 deletions
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 7dfa48bfb..7df6a6b1a 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -140,10 +140,11 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,C // find how many columns do we have to skip to be aligned with the result (if possible) Index skipColumns = 0; // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats) - if( (lhsAlignmentOffset < 0) || (size_t(res)%sizeof(ResScalar)) ) + if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (size_t(res)%sizeof(ResScalar)) ) { alignedSize = 0; alignedStart = 0; + alignmentPattern = NoneAligned; } else if(LhsPacketSize > 4) { @@ -412,10 +413,13 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,R // find how many rows do we have to skip to be aligned with rhs (if possible) Index skipRows = 0; // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats) - if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || (lhsAlignmentOffset < 0) || (rhsAlignmentOffset < 0) ) + if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || + (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) || + (rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) ) { alignedSize = 0; alignedStart = 0; + alignmentPattern = NoneAligned; } else if(LhsPacketSize > 4) { diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 2b599d30d..17bd335f7 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -352,6 +352,52 @@ static void test_large_contraction() } +static void test_matrix_vector() +{ + Tensor<float, 2> t_left(30, 50); + Tensor<float, 1> t_right(50); + Tensor<float, 1> t_result(30); + + t_left.setRandom(); + t_right.setRandom(); + + typedef Map<Eigen::Matrix<float, Dynamic, Dynamic>> MapXf; + MapXf m_left(t_left.data(), 30, 50); + MapXf m_right(t_right.data(), 50, 1); + Eigen::Matrix<float, Dynamic, Dynamic> m_result(30, 1); + + // this contraction should be equivalent to a single matrix multiplication + Eigen::array<DimPair, 1> dims{{DimPair(1, 0)}}; + + // compute results by separate methods + t_result = t_left.contract(t_right, dims); + m_result = m_left * m_right; + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY_IS_APPROX(t_result(i), m_result(i, 0)); + } +} + + +static void test_tensor_vector() +{ + Tensor<float, 3> t_left(7, 13, 17); + Tensor<float, 2> t_right(1, 7); + typedef typename Tensor<float, 1>::DimensionPair DimensionPair; + Eigen::array<DimensionPair, 1> dim_pair01{{{0, 1}}}; + Tensor<float, 3> t_result = t_left.contract(t_right, dim_pair01); + + typedef Map<Eigen::Matrix<float, Dynamic, Dynamic>> MapXf; + MapXf m_left(t_left.data(), 7, 13*17); + MapXf m_right(t_right.data(), 1, 7); + Eigen::Matrix<float, Dynamic, Dynamic> m_result = m_left.transpose() * m_right.transpose(); + + for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) { + VERIFY_IS_APPROX(t_result(i), m_result(i, 0)); + } +} + + void test_cxx11_tensor_contraction() { CALL_SUBTEST(test_evals()); @@ -364,4 +410,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_out_of_order_contraction()); CALL_SUBTEST(test_consistency()); CALL_SUBTEST(test_large_contraction()); + CALL_SUBTEST(test_matrix_vector()); + CALL_SUBTEST(test_tensor_vector()); } |