From ef9dfee7bdc8e0d82c9b7ddf9414ef99d866d7ba Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 24 Sep 2019 12:52:45 -0700 Subject: Tensor block evaluation V2 support for unary/binary/broadcasting --- unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h index 270ad974e..29aa7a97e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h @@ -110,6 +110,8 @@ struct TensorEvaluator, Device> TensorEvaluator::PacketAccess, BlockAccess = TensorEvaluator::BlockAccess & TensorEvaluator::BlockAccess, + BlockAccessV2 = TensorEvaluator::BlockAccessV2 & + TensorEvaluator::BlockAccessV2, PreferBlockAccess = TensorEvaluator::PreferBlockAccess | TensorEvaluator::PreferBlockAccess, Layout = TensorEvaluator::Layout, @@ -120,6 +122,18 @@ struct TensorEvaluator, Device> typename internal::remove_const::type, Index, NumDims, Layout> TensorBlock; + //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// + typedef internal::TensorBlockDescriptor TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator TensorBlockScratch; + + typedef typename TensorEvaluator::TensorBlockV2 + RightTensorBlock; + + typedef internal::TensorBlockAssignment< + Scalar, NumDims, typename RightTensorBlock::XprType, Index> + TensorBlockAssignment; + //===--------------------------------------------------------------------===// + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) @@ -214,6 +228,29 @@ struct TensorEvaluator, Device> m_leftImpl.writeBlock(*block); } } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2( + TensorBlockDesc& desc, TensorBlockScratch& scratch) 
{ + if (TensorEvaluator::RawAccess && + m_leftImpl.data() != NULL) { + // If destination has raw data access, we pass it as a potential + // destination for a block descriptor evaluation. + desc.AddDestinationBuffer( + /*dst_base=*/m_leftImpl.data() + desc.offset(), + /*dst_strides=*/internal::strides(m_leftImpl.dimensions()), + /*total_dst_bytes=*/ + (internal::array_prod(m_leftImpl.dimensions()) * sizeof(Scalar))); + } + + RightTensorBlock block = m_rightImpl.blockV2(desc, scratch); + // If block was evaluated into a destination, there is no need to do + // assignment. + if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) { + m_leftImpl.writeBlockV2(desc, block); + } + block.cleanup(); + } + #ifdef EIGEN_USE_SYCL // binding placeholder accessors to a command group handler for SYCL EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const { -- cgit v1.2.3