diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-07 15:34:26 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-07 15:34:26 -0700 |
commit | f74ab8cb8de5e425ddd25f4b06657926a2ad4599 (patch) | |
tree | 21686c69f54cd402fdf6508cedcfd25750f70898 /unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h | |
parent | 3afb640b5647654f272b1903b71877cb60ed3a78 (diff) |
Add block evaluation to TensorEvalTo and fix a few small bugs
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h | 46 |
1 files changed, 38 insertions, 8 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h index bf7522682..d1e4c82d2 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h @@ -111,22 +111,28 @@ struct TensorEvaluator<const TensorEvalToOp<ArgType, MakePointer_>, Device> IsAligned = TensorEvaluator<ArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = true, - BlockAccessV2 = false, + BlockAccessV2 = true, PreferBlockAccess = false, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = false, // to be implemented RawAccess = true }; - typedef typename internal::TensorBlock< - CoeffReturnType, Index, internal::traits<ArgType>::NumDimensions, Layout> - TensorBlock; - typedef typename internal::TensorBlockReader< - CoeffReturnType, Index, internal::traits<ArgType>::NumDimensions, Layout> - TensorBlockReader; + static const int NumDims = internal::traits<ArgType>::NumDimensions; + + typedef typename internal::TensorBlock<CoeffReturnType, Index, NumDims, Layout> TensorBlock; + typedef typename internal::TensorBlockReader<CoeffReturnType, Index, NumDims, Layout> TensorBlockReader; //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// - typedef internal::TensorBlockNotImplemented TensorBlockV2; + typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; + + typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2 + ArgTensorBlock; + + typedef internal::TensorBlockAssignment< + Scalar, NumDims, typename ArgTensorBlock::XprType, Index> + TensorBlockAssignment; //===--------------------------------------------------------------------===// EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -164,6 +170,30 @@ struct TensorEvaluator<const 
TensorEvalToOp<ArgType, MakePointer_>, Device> m_impl.block(&eval_to_block); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2( + TensorBlockDesc& desc, TensorBlockScratch& scratch) { + // Add `m_buffer` as destination buffer to the block descriptor. + desc.AddDestinationBuffer( + /*dst_base=*/m_buffer + desc.offset(), + /*dst_strides=*/internal::strides<Layout>(m_impl.dimensions()), + /*total_dst_bytes=*/ + (internal::array_prod(m_impl.dimensions()) + * sizeof(Scalar))); + + ArgTensorBlock block = m_impl.blockV2(desc, scratch); + + // If block was evaluated into a destination buffer, there is no need to do + // an assignment. + if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) { + TensorBlockAssignment::Run( + TensorBlockAssignment::target( + desc.dimensions(), internal::strides<Layout>(m_impl.dimensions()), + m_buffer, desc.offset()), + block.expr()); + } + block.cleanup(); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); } |