From 60ae24ee1a6c16114de456d77fcfba6f5a1160ca Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 2 Oct 2019 12:44:06 -0700 Subject: Add block evaluation to TensorReshaping/TensorCasting/TensorPadding/TensorSelect --- .../Eigen/CXX11/src/Tensor/TensorConversion.h | 55 +++++++++++++++++----- 1 file changed, 44 insertions(+), 11 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h index a8160e17e..cc3e67677 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h @@ -294,23 +294,45 @@ struct TensorEvaluator, Device> typedef typename Storage::Type EvaluatorPointerType; enum { - IsAligned = false, - PacketAccess = + IsAligned = false, + PacketAccess = #ifndef EIGEN_USE_SYCL - true, + true, #else - TensorEvaluator::PacketAccess & - internal::type_casting_traits::VectorizedCast, + TensorEvaluator::PacketAccess & + internal::type_casting_traits::VectorizedCast, #endif - BlockAccess = false, - BlockAccessV2 = false, - PreferBlockAccess = false, - Layout = TensorEvaluator::Layout, - RawAccess = false + BlockAccess = false, + BlockAccessV2 = TensorEvaluator::BlockAccessV2, + PreferBlockAccess = TensorEvaluator::PreferBlockAccess, + Layout = TensorEvaluator::Layout, + RawAccess = false }; + static const int NumDims = internal::array_size::value; + //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator TensorBlockScratch; + + typedef typename TensorEvaluator::TensorBlockV2 + ArgTensorBlock; + + struct TensorConversionOpBlockFactory { + template + struct XprType { + typedef TensorConversionOp type; + }; + + template + typename XprType::type expr(const ArgXprType& expr) const { + return typename XprType::type(expr); + } + }; + + typedef internal::TensorUnaryExprBlock + TensorBlockV2; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -376,6 +398,17 @@ struct TensorEvaluator, Device> } } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void getResourceRequirements( + std::vector* resources) const { + m_impl.getResourceRequirements(resources); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const { + return TensorBlockV2(m_impl.blockV2(desc, scratch), + TensorConversionOpBlockFactory()); + } + EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; } /// required by sycl in order to extract the sycl accessor -- cgit v1.2.3