aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Mehdi Goli <mehdi.goli@codeplay.com>2016-12-14 15:30:37 +0000
committerGravatar Mehdi Goli <mehdi.goli@codeplay.com>2016-12-14 15:30:37 +0000
commit2d4a091beb9e55664c1475137af7166d524cbc1d (patch)
treed9e4baec0be3eb3c8a4bb2451701f7e49730daa1 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent3d59a477201d4d4f34b4332fda699c21387cf726 (diff)
Adding tensor contraction operation backend for Sycl; adding test for contractionOp sycl backend; adding temporary solution to prevent memory leak in buffer; cleaning up cxx11_tensor_buildins_sycl.h
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h10
1 files changed, 7 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 20b29e5fd..2ac6abf69 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -156,9 +156,9 @@ struct TensorContractionEvaluatorBase
m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
op.rhsExpression(), op.lhsExpression()), device),
m_device(device),
- m_result(NULL) {
+ m_result(NULL), m_expr_indices(op.indices()) {
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
- static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
+ static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
YOU_MADE_A_PROGRAMMING_MISTAKE);
@@ -327,7 +327,7 @@ struct TensorContractionEvaluatorBase
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar * data) {
m_leftImpl.evalSubExprsIfNeeded(NULL);
m_rightImpl.evalSubExprsIfNeeded(NULL);
if (data) {
@@ -564,6 +564,9 @@ struct TensorContractionEvaluatorBase
TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
const Device& m_device;
Scalar* m_result;
+ /// required for sycl
+ const Indices m_expr_indices;
+
};
@@ -621,6 +624,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
}
+
};
} // end namespace Eigen