diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h index 270ad974e..29aa7a97e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h @@ -110,6 +110,8 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> TensorEvaluator<RightArgType, Device>::PacketAccess, BlockAccess = TensorEvaluator<LeftArgType, Device>::BlockAccess & TensorEvaluator<RightArgType, Device>::BlockAccess, + BlockAccessV2 = TensorEvaluator<LeftArgType, Device>::BlockAccessV2 & + TensorEvaluator<RightArgType, Device>::BlockAccessV2, PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess | TensorEvaluator<RightArgType, Device>::PreferBlockAccess, Layout = TensorEvaluator<LeftArgType, Device>::Layout, @@ -120,6 +122,18 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> typename internal::remove_const<Scalar>::type, Index, NumDims, Layout> TensorBlock; + //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// + typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc; + typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch; + + typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlockV2 + RightTensorBlock; + + typedef internal::TensorBlockAssignment< + Scalar, NumDims, typename RightTensorBlock::XprType, Index> + TensorBlockAssignment; + //===--------------------------------------------------------------------===// + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) @@ -214,6 +228,29 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> m_leftImpl.writeBlock(*block); } } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2( + TensorBlockDesc& desc, TensorBlockScratch& scratch) { + if (TensorEvaluator<LeftArgType, Device>::RawAccess && + m_leftImpl.data() != NULL) { + // If destination has raw data access, we pass it as a potential + // destination for a block descriptor evaluation. + desc.AddDestinationBuffer( + /*dst_base=*/m_leftImpl.data() + desc.offset(), + /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()), + /*total_dst_bytes=*/ + (internal::array_prod(m_leftImpl.dimensions()) * sizeof(Scalar))); + } + + RightTensorBlock block = m_rightImpl.blockV2(desc, scratch); + // If block was evaluated into a destination, there is no need to do + // assignment. + if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) { + m_leftImpl.writeBlockV2(desc, block); + } + block.cleanup(); + } + #ifdef EIGEN_USE_SYCL // binding placeholder accessors to a command group handler for SYCL EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const { |