diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h | 55 |
1 files changed, 44 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index a8160e17e..cc3e67677 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -294,23 +294,45 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device> typedef typename Storage::Type EvaluatorPointerType; enum { - IsAligned = false, - PacketAccess = + IsAligned = false, + PacketAccess = #ifndef EIGEN_USE_SYCL - true, + true, #else - TensorEvaluator<ArgType, Device>::PacketAccess & - internal::type_casting_traits<SrcType, TargetType>::VectorizedCast, + TensorEvaluator<ArgType, Device>::PacketAccess & + internal::type_casting_traits<SrcType, TargetType>::VectorizedCast, #endif - BlockAccess = false, - BlockAccessV2 = false, - PreferBlockAccess = false, - Layout = TensorEvaluator<ArgType, Device>::Layout, - RawAccess = false + BlockAccess = false, + BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2, + PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + RawAccess = false }; + static const int NumDims = internal::array_size<Dimensions>::value; + //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; + + typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2 + ArgTensorBlock; + + struct TensorConversionOpBlockFactory { + template <typename ArgXprType> + struct XprType { + typedef TensorConversionOp<TargetType, const ArgXprType> type; + }; + + template <typename ArgXprType> + typename XprType<ArgXprType>::type expr(const ArgXprType& expr) const { + return typename XprType<ArgXprType>::type(expr); + } + }; + + typedef internal::TensorUnaryExprBlock<TensorConversionOpBlockFactory, + ArgTensorBlock> + TensorBlockV2; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -376,6 +398,17 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device> } } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void getResourceRequirements( + std::vector<internal::TensorOpResourceRequirements>* resources) const { + m_impl.getResourceRequirements(resources); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const { + return TensorBlockV2(m_impl.blockV2(desc, scratch), + TensorConversionOpBlockFactory()); + } + EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; } /// required by sycl in order to extract the sycl accessor |