diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index c9d78ba9b..ab3a979a8 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -365,7 +365,8 @@ struct TensorEvaluator<const TensorReshapingOp<NewDimensions, ArgType>, Device> } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 - blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const { + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch, + bool /*root_of_expr_ast*/ = false) const { eigen_assert(m_impl.data() != NULL); eigen_assert((kind == Runtime) || (kind == OneByN && desc.dimensions()[0] == 1) || @@ -611,7 +612,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess, - BlockAccessV2 = false, + BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2, PreferBlockAccess = true, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = false, @@ -624,7 +625,12 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi typedef typename TensorBlock::Dimensions TensorBlockDimensions; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; + + // Tensor slicing does not change the block type. + typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2 + TensorBlockV2; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -804,6 +810,15 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi m_impl.block(&input_block); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 + blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch, + bool /*root_of_expr_ast*/ = false) const { + TensorBlockDesc arg_desc = desc.WithOffset(srcCoeff(desc.offset())); + TensorBlockV2 block = m_impl.blockV2(arg_desc, scratch); + if (!arg_desc.HasDestinationBuffer()) desc.DropDestinationBuffer(); + return block; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Storage::Type data() const { typename Storage::Type result = constCast(m_impl.data()); if (result) { @@ -900,7 +915,7 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device> IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess, - BlockAccessV2 = false, + BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2, PreferBlockAccess = true, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = false, @@ -913,7 +928,8 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device> typedef typename TensorBlock::Dimensions TensorBlockDimensions; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -987,6 +1003,13 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device> block.block_strides(), TensorBlockDimensions(this->m_inputStrides), const_cast<ScalarNoConst*>(block.data()))); } + + template<typename TensorBlockV2> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writeBlockV2( + const TensorBlockDesc& desc, const TensorBlockV2& block) { + TensorBlockDesc arg_desc = desc.WithOffset(this->srcCoeff(desc.offset())); + this->m_impl.writeBlockV2(arg_desc, block); + } }; namespace internal { |