aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Godeffroy Valet <godeffroy.valet@m4x.org>2015-07-25 11:58:36 +0200
committerGravatar Godeffroy Valet <godeffroy.valet@m4x.org>2015-07-25 11:58:36 +0200
commit2195822df64c34eacc411043a197ce701ae6b135 (patch)
treeb64e793502bc561b4d443f0352115d3959a6413b /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent4b3052c54d14f486c2665530f06cdeff636169e4 (diff)
Allowed tensor contraction operation with an empty array of dimension pairs, which performs a tensor product.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h30
1 files changed, 16 insertions, 14 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 59ae4a2d0..85ae9dd6a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -66,7 +66,7 @@ class BaseTensorContractionMapper {
const bool left = (side == Lhs);
Index nocontract_val = left ? row : col;
Index linidx = 0;
- for (int i = array_size<nocontract_t>::value - 1; i > 0; i--) {
+ for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
const Index idx = nocontract_val / m_ij_strides[i];
linidx += idx * m_nocontract_strides[i];
nocontract_val -= idx * m_ij_strides[i];
@@ -81,17 +81,19 @@ class BaseTensorContractionMapper {
}
Index contract_val = left ? col : row;
- for (int i = array_size<contract_t>::value - 1; i > 0; i--) {
+ for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx = contract_val / m_k_strides[i];
linidx += idx * m_contract_strides[i];
contract_val -= idx * m_k_strides[i];
}
- EIGEN_STATIC_ASSERT(array_size<contract_t>::value > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
- if (side == Rhs && inner_dim_contiguous) {
- eigen_assert(m_contract_strides[0] == 1);
- linidx += contract_val;
- } else {
- linidx += contract_val * m_contract_strides[0];
+
+ if(array_size<contract_t>::value > 0) {
+ if (side == Rhs && inner_dim_contiguous) {
+ eigen_assert(m_contract_strides[0] == 1);
+ linidx += contract_val;
+ } else {
+ linidx += contract_val * m_contract_strides[0];
+ }
}
return linidx;
@@ -102,7 +104,7 @@ class BaseTensorContractionMapper {
const bool left = (side == Lhs);
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
Index linidx[2] = {0, 0};
- for (int i = array_size<nocontract_t>::value - 1; i > 0; i--) {
+ for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
linidx[0] += idx0 * m_nocontract_strides[i];
@@ -122,7 +124,7 @@ class BaseTensorContractionMapper {
}
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
- for (int i = array_size<contract_t>::value - 1; i > 0; i--) {
+ for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx0 = contract_val[0] / m_k_strides[i];
const Index idx1 = contract_val[1] / m_k_strides[i];
linidx[0] += idx0 * m_contract_strides[i];
@@ -130,7 +132,7 @@ class BaseTensorContractionMapper {
contract_val[0] -= idx0 * m_k_strides[i];
contract_val[1] -= idx1 * m_k_strides[i];
}
- EIGEN_STATIC_ASSERT(array_size<contract_t>::value > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
+
if (side == Rhs && inner_dim_contiguous) {
eigen_assert(m_contract_strides[0] == 1);
linidx[0] += contract_val[0];
@@ -509,8 +511,6 @@ struct TensorContractionEvaluatorBase
static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
YOU_MADE_A_PROGRAMMING_MISTAKE);
- eigen_assert((internal::array_size<contract_t>::value > 0) && "Must contract on some indices");
-
DSizes<Index, LDims> eval_left_dims;
DSizes<Index, RDims> eval_right_dims;
@@ -558,7 +558,9 @@ struct TensorContractionEvaluatorBase
m_i_strides[0] = 1;
m_j_strides[0] = 1;
- m_k_strides[0] = 1;
+ if(ContractDims) {
+ m_k_strides[0] = 1;
+ }
m_i_size = 1;
m_j_size = 1;