From 3a2dd352ae7137099ce95a28139ffeb1eca23c73 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 11 Jul 2016 13:43:41 -0700 Subject: Improved the contraction mapper to properly support tensor products --- .../CXX11/src/Tensor/TensorContractionMapper.h | 74 +++++++++++----------- 1 file changed, 38 insertions(+), 36 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index b27e1a1b4..9b2cb3ff6 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -130,19 +130,19 @@ class SimpleTensorContractionMapper { } Index contract_val = left ? col : row; - for (int i = static_cast(array_size::value) - 1; i > 0; i--) { - const Index idx = contract_val / m_k_strides[i]; - linidx += idx * m_contract_strides[i]; - contract_val -= idx * m_k_strides[i]; - } - if(array_size::value > 0) { - if (side == Rhs && inner_dim_contiguous) { - eigen_assert(m_contract_strides[0] == 1); - linidx += contract_val; - } else { - linidx += contract_val * m_contract_strides[0]; - } + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx = contract_val / m_k_strides[i]; + linidx += idx * m_contract_strides[i]; + contract_val -= idx * m_k_strides[i]; + } + + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx += contract_val; + } else { + linidx += contract_val * m_contract_strides[0]; + } } return linidx; @@ -153,15 +153,15 @@ class SimpleTensorContractionMapper { const bool left = (side == Lhs); Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; Index linidx[2] = {0, 0}; - for (int i = static_cast(array_size::value) - 1; i > 0; i--) { - const Index idx0 = nocontract_val[0] / m_ij_strides[i]; - const Index idx1 = nocontract_val[1] / m_ij_strides[i]; - linidx[0] += idx0 * m_nocontract_strides[i]; - linidx[1] += idx1 * m_nocontract_strides[i]; - nocontract_val[0] -= idx0 * m_ij_strides[i]; - nocontract_val[1] -= idx1 * m_ij_strides[i]; - } if (array_size::value > array_size::value) { + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx0 = nocontract_val[0] / m_ij_strides[i]; + const Index idx1 = nocontract_val[1] / m_ij_strides[i]; + linidx[0] += idx0 * m_nocontract_strides[i]; + linidx[1] += idx1 * m_nocontract_strides[i]; + nocontract_val[0] -= idx0 * m_ij_strides[i]; + nocontract_val[1] -= idx1 * m_ij_strides[i]; + } if (side == Lhs && inner_dim_contiguous) { eigen_assert(m_nocontract_strides[0] == 1); linidx[0] += nocontract_val[0]; @@ -173,22 +173,24 @@ class SimpleTensorContractionMapper { } Index contract_val[2] = {left ? col : row, left ? col : row + distance}; - for (int i = static_cast(array_size::value) - 1; i > 0; i--) { - const Index idx0 = contract_val[0] / m_k_strides[i]; - const Index idx1 = contract_val[1] / m_k_strides[i]; - linidx[0] += idx0 * m_contract_strides[i]; - linidx[1] += idx1 * m_contract_strides[i]; - contract_val[0] -= idx0 * m_k_strides[i]; - contract_val[1] -= idx1 * m_k_strides[i]; - } + if (array_size::value> 0) { + for (int i = static_cast(array_size::value) - 1; i > 0; i--) { + const Index idx0 = contract_val[0] / m_k_strides[i]; + const Index idx1 = contract_val[1] / m_k_strides[i]; + linidx[0] += idx0 * m_contract_strides[i]; + linidx[1] += idx1 * m_contract_strides[i]; + contract_val[0] -= idx0 * m_k_strides[i]; + contract_val[1] -= idx1 * m_k_strides[i]; + } - if (side == Rhs && inner_dim_contiguous) { - eigen_assert(m_contract_strides[0] == 1); - linidx[0] += contract_val[0]; - linidx[1] += contract_val[1]; - } else { - linidx[0] += contract_val[0] * m_contract_strides[0]; - linidx[1] += contract_val[1] * m_contract_strides[0]; + if (side == Rhs && inner_dim_contiguous) { + eigen_assert(m_contract_strides[0] == 1); + linidx[0] += contract_val[0]; + linidx[1] += contract_val[1]; + } else { + linidx[0] += contract_val[0] * m_contract_strides[0]; + linidx[1] += contract_val[1] * m_contract_strides[0]; + } } return IndexPair(linidx[0], linidx[1]); } @@ -200,7 +202,7 @@ class SimpleTensorContractionMapper { return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { - return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; + return ((side == Lhs) && inner_dim_contiguous && array_size::value > 0) ? m_contract_strides[0] : 1; } protected: -- cgit v1.2.3