aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-09-24 12:52:45 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-09-24 12:52:45 -0700
commitef9dfee7bdc8e0d82c9b7ddf9414ef99d866d7ba (patch)
tree490a8ae1f247cf226475f504ea1d3ab305b98097 /unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
parentefd9867ff0e8df23016ac6c9828d0d7bf8bec1b1 (diff)
Tensor block evaluation V2 support for unary/binary/broadcasting
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h37
1 files changed, 37 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
index 270ad974e..29aa7a97e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h
@@ -110,6 +110,8 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
TensorEvaluator<RightArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<LeftArgType, Device>::BlockAccess &
TensorEvaluator<RightArgType, Device>::BlockAccess,
+ BlockAccessV2 = TensorEvaluator<LeftArgType, Device>::BlockAccessV2 &
+ TensorEvaluator<RightArgType, Device>::BlockAccessV2,
PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess |
TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
@@ -120,6 +122,18 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
typename internal::remove_const<Scalar>::type, Index, NumDims, Layout>
TensorBlock;
+ //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;  // descriptor type taken by evalBlockV2
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;  // scratch allocator type taken by evalBlockV2
+
+ typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlockV2
+ RightTensorBlock;  // V2 block type exposed by the right-hand side evaluator
+
+ typedef internal::TensorBlockAssignment<
+ Scalar, NumDims, typename RightTensorBlock::XprType, Index>
+ TensorBlockAssignment;
+ //===--------------------------------------------------------------------===//
+
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
m_leftImpl(op.lhsExpression(), device),
m_rightImpl(op.rhsExpression(), device)
@@ -214,6 +228,29 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device>
m_leftImpl.writeBlock(*block);
}
}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2(
+ TensorBlockDesc& desc, TensorBlockScratch& scratch) {
+ if (TensorEvaluator<LeftArgType, Device>::RawAccess &&
+ m_leftImpl.data() != NULL) {
+ // The left-hand side exposes a non-null raw data pointer: register it
+ // (shifted by the block offset) with the descriptor as a candidate output buffer.
+ desc.AddDestinationBuffer(
+ /*dst_base=*/m_leftImpl.data() + desc.offset(),
+ /*dst_strides=*/internal::strides<Layout>(m_leftImpl.dimensions()),
+ /*total_dst_bytes=*/
+ (internal::array_prod(m_leftImpl.dimensions()) * sizeof(Scalar)));
+ }
+
+ RightTensorBlock block = m_rightImpl.blockV2(desc, scratch);  // request the right-hand side block described by desc
+ // A block reported as kMaterializedInOutput is treated as already written to
+ // the destination; any other kind is written through the left-hand side.
+ if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
+ m_leftImpl.writeBlockV2(desc, block);
+ }
+ block.cleanup();
+ }
+
#ifdef EIGEN_USE_SYCL
// binding placeholder accessors to a command group handler for SYCL
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {