aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-07 15:34:26 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-07 15:34:26 -0700
commitf74ab8cb8de5e425ddd25f4b06657926a2ad4599 (patch)
tree21686c69f54cd402fdf6508cedcfd25750f70898 /unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h
parent3afb640b5647654f272b1903b71877cb60ed3a78 (diff)
Add block evaluation to TensorEvalTo and fix a few small bugs
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h46
1 files changed, 38 insertions, 8 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h
index bf7522682..d1e4c82d2 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h
@@ -111,22 +111,28 @@ struct TensorEvaluator<const TensorEvalToOp<ArgType, MakePointer_>, Device>
IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = true,
- BlockAccessV2 = false,
+ BlockAccessV2 = true,
PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = false, // to be implemented
RawAccess = true
};
- typedef typename internal::TensorBlock<
- CoeffReturnType, Index, internal::traits<ArgType>::NumDimensions, Layout>
- TensorBlock;
- typedef typename internal::TensorBlockReader<
- CoeffReturnType, Index, internal::traits<ArgType>::NumDimensions, Layout>
- TensorBlockReader;
+ // Number of dimensions of the argument expression.
+ static const int NumDims = internal::traits<ArgType>::NumDimensions;
+
+ // Block and block-reader types parameterized on this evaluator's coefficient
+ // type, index type, dimension count and layout.
+ typedef typename internal::TensorBlock<CoeffReturnType, Index, NumDims, Layout> TensorBlock;
+ typedef typename internal::TensorBlockReader<CoeffReturnType, Index, NumDims, Layout> TensorBlockReader;
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
- typedef internal::TensorBlockNotImplemented TensorBlockV2;
+ // Descriptor and scratch-allocator types taken by evalBlockV2().
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ // Block type exposed by the argument's evaluator.
+ typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2
+ ArgTensorBlock;
+
+ // Assignment helper used by evalBlockV2() to write an ArgTensorBlock
+ // expression (its XprType) into the Scalar output buffer.
+ typedef internal::TensorBlockAssignment<
+ Scalar, NumDims, typename ArgTensorBlock::XprType, Index>
+ TensorBlockAssignment;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
@@ -164,6 +170,30 @@ struct TensorEvaluator<const TensorEvalToOp<ArgType, MakePointer_>, Device>
m_impl.block(&eval_to_block);
}
+ // Evaluates one block of the argument expression and stores the result in
+ // `m_buffer`.
+ //
+ // `desc` supplies the block's offset and dimensions; `scratch` is forwarded
+ // unchanged to `m_impl.blockV2()`. `m_buffer` is first registered on `desc`
+ // as a destination buffer, so the argument evaluator may materialize the
+ // block directly into it; when it does not, the returned block expression is
+ // assigned into `m_buffer` here. Note that `desc` is modified by this call.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalBlockV2(
+ TensorBlockDesc& desc, TensorBlockScratch& scratch) {
+ // Add `m_buffer` as destination buffer to the block descriptor.
+ // The base pointer is already shifted by the block offset, while the strides
+ // and the byte count are derived from the argument's full dimensions.
+ desc.AddDestinationBuffer(
+ /*dst_base=*/m_buffer + desc.offset(),
+ /*dst_strides=*/internal::strides<Layout>(m_impl.dimensions()),
+ /*total_dst_bytes=*/
+ (internal::array_prod(m_impl.dimensions())
+ * sizeof(Scalar)));
+
+ // Ask the argument evaluator for the block described by `desc`.
+ ArgTensorBlock block = m_impl.blockV2(desc, scratch);
+
+ // If block was evaluated into a destination buffer, there is no need to do
+ // an assignment.
+ if (block.kind() != internal::TensorBlockKind::kMaterializedInOutput) {
+ // Target: the region of `m_buffer` starting at `desc.offset()` with the
+ // block's dimensions, addressed with strides of the full output.
+ TensorBlockAssignment::Run(
+ TensorBlockAssignment::target(
+ desc.dimensions(), internal::strides<Layout>(m_impl.dimensions()),
+ m_buffer, desc.offset()),
+ block.expr());
+ }
+ // Runs on both paths; presumably releases resources held by the block
+ // (confirm against TensorBlock.h).
+ block.cleanup();
+ }
+
// Delegates cleanup to `m_impl`; nothing else is done here.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_impl.cleanup();
}