From d380c23b2cc0b02e10819e779c73cde2c62603b2 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 14 Oct 2019 14:31:59 -0700 Subject: Block evaluation for TensorGenerator/TensorReverse/TensorShuffling --- .../Eigen/CXX11/src/Tensor/TensorMorphing.h | 33 ++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index c9d78ba9b..ab3a979a8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -365,7 +365,8 @@ struct TensorEvaluator, Device> } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 - blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const { + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch, + bool /*root_of_expr_ast*/ = false) const { eigen_assert(m_impl.data() != NULL); eigen_assert((kind == Runtime) || (kind == OneByN && desc.dimensions()[0] == 1) || @@ -611,7 +612,7 @@ struct TensorEvaluator, Devi IsAligned = false, PacketAccess = TensorEvaluator::PacketAccess, BlockAccess = TensorEvaluator::BlockAccess, - BlockAccessV2 = false, + BlockAccessV2 = TensorEvaluator::BlockAccessV2, PreferBlockAccess = true, Layout = TensorEvaluator::Layout, CoordAccess = false, @@ -624,7 +625,12 @@ struct TensorEvaluator, Devi typedef typename TensorBlock::Dimensions TensorBlockDimensions; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator TensorBlockScratch; + + // Tensor slicing does not change the block type. + typedef typename TensorEvaluator::TensorBlockV2 + TensorBlockV2; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -804,6 +810,15 @@ struct TensorEvaluator, Devi m_impl.block(&input_block); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch, + bool /*root_of_expr_ast*/ = false) const { + TensorBlockDesc arg_desc = desc.WithOffset(srcCoeff(desc.offset())); + TensorBlockV2 block = m_impl.blockV2(arg_desc, scratch); + if (!arg_desc.HasDestinationBuffer()) desc.DropDestinationBuffer(); + return block; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Storage::Type data() const { typename Storage::Type result = constCast(m_impl.data()); if (result) { @@ -900,7 +915,7 @@ struct TensorEvaluator, Device> IsAligned = false, PacketAccess = TensorEvaluator::PacketAccess, BlockAccess = TensorEvaluator::BlockAccess, - BlockAccessV2 = false, + BlockAccessV2 = TensorEvaluator::BlockAccessV2, PreferBlockAccess = true, Layout = TensorEvaluator::Layout, CoordAccess = false, @@ -913,7 +928,8 @@ struct TensorEvaluator, Device> typedef typename TensorBlock::Dimensions TensorBlockDimensions; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator TensorBlockScratch; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -987,6 +1003,13 @@ struct TensorEvaluator, Device> block.block_strides(), TensorBlockDimensions(this->m_inputStrides), const_cast(block.data()))); } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writeBlockV2( + const TensorBlockDesc& desc, const TensorBlockV2& block) { + TensorBlockDesc arg_desc = desc.WithOffset(this->srcCoeff(desc.offset())); + this->m_impl.writeBlockV2(arg_desc, block); + } }; namespace internal { -- cgit v1.2.3