aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-18 14:21:01 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-18 14:21:01 -0700
commit79d4129cce4055955e7a340822df31b6f8d25edc (patch)
treec7a45a502784718dc1ef2e86b6ffcc764cb47bd9 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent6e5a3b898fde197d4748315694103f577c0f503f (diff)
Specify default output kernel for TensorContractionOp
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h6
1 files changed, 3 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index c7c443a59..0e69cd40c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -182,7 +182,7 @@ struct NoOpOutputKernel {
Index num_cols) const {}
};
-template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
+template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
{
public:
@@ -507,7 +507,7 @@ struct TensorContractionEvaluatorBase
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- #if !defined(EIGEN_HIPCC)
+ #if !defined(EIGEN_HIPCC)
EIGEN_DEVICE_FUNC
#endif
void evalGemv(Scalar* buffer) const {
@@ -556,7 +556,7 @@ struct TensorContractionEvaluatorBase
}
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
- #if !defined(EIGEN_HIPCC)
+ #if !defined(EIGEN_HIPCC)
EIGEN_DEVICE_FUNC
#endif
void evalGemm(Scalar* buffer) const {