aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h37
1 files changed, 37 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
index 270ad974e..29aa7a97e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
@@ -110,6 +110,8 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
TensorEvaluator<RightArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<LeftArgType, Device>::BlockAccess &
TensorEvaluator<RightArgType, Device>::BlockAccess,
+ BlockAccessV2 = TensorEvaluator<LeftArgType, Device>::BlockAccessV2 &
+ TensorEvaluator<RightArgType, Device>::BlockAccessV2,
PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess |
TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
@@ -120,6 +122,18 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
typename internal::remove_const<Scalar>::type, Index, NumDims, Layout>
TensorBlock;
+ //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlockV2
+ RightTensorBlock;
+
+ typedef internal::TensorBlockAssignment<
+ Scalar, NumDims, typename RightTensorBlock::XprType, Index>
+ TensorBlockAssignment;
+ //===--------------------------------------------------------------------===//
+
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
m_leftImpl(op.lhsExpression(), device),
m_rightImpl(op.rhsExpression(), device)
@@ -214,6 +228,29 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
m_leftImpl.writeBlock(*block);
}
}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2(
+ TensorBlockDesc& desc, TensorBlockScratch& scratch) {
+ if (TensorEvaluator<LeftArgType, Device>::RawAccess &&
+ m_leftImpl.data() != NULL) {
+ // If destination has raw data access, we pass it as a potential
+ // destination for a block descriptor evaluation.
+ desc.AddDestinationBuffer(
+ /*dst_base=*/m_leftImpl.data() + desc.offset(),
+ /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()),
+ /*total_dst_bytes=*/
+ (internal::array_prod(m_leftImpl.dimensions()) * sizeof(Scalar)));
+ }
+
+ RightTensorBlock block = m_rightImpl.blockV2(desc, scratch);
+ // If block was evaluated into a destination, there is no need to do
+ // assignment.
+ if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
+ m_leftImpl.writeBlockV2(desc, block);
+ }
+ block.cleanup();
+ }
+
#ifdef EIGEN_USE_SYCL
// binding placeholder accessors to a command group handler for SYCL
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {