diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h index 7738f18fb..839c6e3e5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h @@ -260,6 +260,21 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy return rslt; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost + costPerCoeff(bool vectorized) const { + const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + + 2 * TensorOpCost::MulCost<Index>() + + TensorOpCost::DivCost<Index>() + + TensorOpCost::ModCost<Index>()); + const double lhs_size = m_leftImpl.dimensions().TotalSize(); + const double rhs_size = m_rightImpl.dimensions().TotalSize(); + return (lhs_size / (lhs_size + rhs_size)) * + m_leftImpl.costPerCoeff(vectorized) + + (rhs_size / (lhs_size + rhs_size)) * + m_rightImpl.costPerCoeff(vectorized) + + TensorOpCost(0, 0, compute_cost); + } + EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } protected: |