aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h41
1 files changed, 17 insertions, 24 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 6f113b903..20b29e5fd 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -25,8 +25,9 @@ template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
// Type promotion to handle the case where the types of the lhs and the rhs are different.
- typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
- typename RhsXprType::Scalar>::ret Scalar;
+ typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
+ typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;
+
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<LhsXprType>::Index,
@@ -37,7 +38,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
typedef typename remove_reference<RhsNested>::type _RhsNested;
// From NumDims below.
- static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
+ static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
static const int Layout = traits<LhsXprType>::Layout;
enum {
@@ -65,7 +66,7 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
typedef Device_ Device;
// From NumDims below.
- static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
+ static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};
} // end namespace internal
@@ -75,8 +76,8 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
{
public:
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
- typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
- typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
+ typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
+ typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
@@ -140,11 +141,11 @@ struct TensorContractionEvaluatorBase
static const int RDims =
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
static const int ContractDims = internal::array_size<Indices>::value;
- static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
+ static const int NumDims = LDims + RDims - 2 * ContractDims;
typedef array<Index, ContractDims> contract_t;
- typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
- typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
+ typedef array<Index, LDims - ContractDims> left_nocontract_t;
+ typedef array<Index, RDims - ContractDims> right_nocontract_t;
typedef DSizes<Index, NumDims> Dimensions;
@@ -218,11 +219,9 @@ struct TensorContractionEvaluatorBase
rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
}
- m_i_strides[0] = 1;
- m_j_strides[0] = 1;
- if(ContractDims) {
- m_k_strides[0] = 1;
- }
+ if (m_i_strides.size() > 0) m_i_strides[0] = 1;
+ if (m_j_strides.size() > 0) m_j_strides[0] = 1;
+ if (m_k_strides.size() > 0) m_k_strides[0] = 1;
m_i_size = 1;
m_j_size = 1;
@@ -318,11 +317,6 @@ struct TensorContractionEvaluatorBase
}
}
- // Scalar case. We represent the result as a 1d tensor of size 1.
- if (LDims + RDims == 2 * ContractDims) {
- m_dimensions[0] = 1;
- }
-
// If the layout is RowMajor, we need to reverse the m_dimensions
if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
@@ -510,7 +504,7 @@ struct TensorContractionEvaluatorBase
// call gebp (matrix kernel)
// The parameters here are copied from Eigen's GEMM implementation
- gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
+ gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
}
}
}
@@ -607,15 +601,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
static const int ContractDims = internal::array_size<Indices>::value;
typedef array<Index, ContractDims> contract_t;
- typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
- typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
+ typedef array<Index, LDims - ContractDims> left_nocontract_t;
+ typedef array<Index, RDims - ContractDims> right_nocontract_t;
- static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
+ static const int NumDims = LDims + RDims - 2 * ContractDims;
// Could we use NumDimensions here?
typedef DSizes<Index, NumDims> Dimensions;
-
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
Base(op, device) { }