aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
diff options
context:
space:
mode:
authorGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-02-22 16:36:24 +0000
committerGravatar Mehdi Goli <mehdi.goli@codeplay.com>2017-02-22 16:36:24 +0000
commit89dfd51fae868393b66b1949638e03de2ba17c1f (patch)
tree6c20c89d7b8fc27639f1ae25af7790d22f37892e /unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
parent4f07ac16b0722597c55e2783cee33606a1f5e390 (diff)
Adding Sycl Backend for TensorGenerator.h.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h13
1 files changed, 10 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
index eb1d4934e..ca87f493a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
@@ -97,10 +97,9 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
- : m_generator(op.generator())
+ : m_generator(op.generator()), m_argImpl(op.expression(), device)
{
- TensorEvaluator<ArgType, Device> impl(op.expression(), device);
- m_dimensions = impl.dimensions();
+ m_dimensions = m_argImpl.dimensions();
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
m_strides[0] = 1;
@@ -155,6 +154,12 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
+ /// required by sycl in order to extract the accessor
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
+ /// required by sycl in order to extract the accessor
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Generator& functor() const { return m_generator; }
+
+
protected:
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void extract_coordinates(Index index, array<Index, NumDims>& coords) const {
@@ -178,6 +183,8 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
Dimensions m_dimensions;
array<Index, NumDims> m_strides;
Generator m_generator;
+ // required by sycl
+ TensorEvaluator<ArgType, Device> m_argImpl;
};
} // end namespace Eigen