diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h index 998757d14..755170a34 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h @@ -134,13 +134,22 @@ struct TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device> } } +#ifdef EIGEN_USE_SYCL + // binding placeholder accessors to a command group handler for SYCL + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const { + m_impl.bind(cgh); + } +#endif + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; + typedef StorageMemory<CoeffReturnType, Device> Storage; + typedef typename Storage::Type EvaluatorPointerType; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) { return m_impl.evalSubExprsIfNeeded(data); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { @@ -162,7 +171,9 @@ struct TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device> return m_impl.costPerCoeff(vectorized); } - EIGEN_DEVICE_FUNC typename Eigen::internal::traits<XprType>::PointerType data() const { return m_impl.data(); } + EIGEN_DEVICE_FUNC typename Storage::Type data() const { + return constCast(m_impl.data()); + } const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; } @@ -192,7 +203,7 @@ template<typename ArgType, typename Device> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } - + typedef typename XprType::Index Index; typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; |