aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-14 14:31:59 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-14 14:31:59 -0700
commitd380c23b2cc0b02e10819e779c73cde2c62603b2 (patch)
tree09d204c87ed6a9f55aa0c6305d7e4199a71dbf8a /unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
parent39fb9eeccf2e79542acad9bbf5196e462c1b2cee (diff)
Block evaluation for TensorGenerator/TensorReverse/TensorShuffling
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h33
1 files changed, 28 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index c9d78ba9b..ab3a979a8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -365,7 +365,8 @@ struct TensorEvaluator<const TensorReshapingOp<NewDimensions, ArgType>, Device>
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
- blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch) const {
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch,
+ bool /*root_of_expr_ast*/ = false) const {
eigen_assert(m_impl.data() != NULL);
eigen_assert((kind == Runtime) ||
(kind == OneByN && desc.dimensions()[0] == 1) ||
@@ -611,7 +612,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
- BlockAccessV2 = false,
+ BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2,
PreferBlockAccess = true,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = false,
@@ -624,7 +625,12 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
typedef typename TensorBlock::Dimensions TensorBlockDimensions;
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
- typedef internal::TensorBlockNotImplemented TensorBlockV2;
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
+
+ // Tensor slicing does not change the block type.
+ typedef typename TensorEvaluator<const ArgType, Device>::TensorBlockV2
+ TensorBlockV2;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
@@ -804,6 +810,15 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
m_impl.block(&input_block);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+ blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch,
+ bool /*root_of_expr_ast*/ = false) const {
+ TensorBlockDesc arg_desc = desc.WithOffset(srcCoeff(desc.offset()));
+ TensorBlockV2 block = m_impl.blockV2(arg_desc, scratch);
+ if (!arg_desc.HasDestinationBuffer()) desc.DropDestinationBuffer();
+ return block;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Storage::Type data() const {
typename Storage::Type result = constCast(m_impl.data());
if (result) {
@@ -900,7 +915,7 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
- BlockAccessV2 = false,
+ BlockAccessV2 = TensorEvaluator<ArgType, Device>::BlockAccessV2,
PreferBlockAccess = true,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = false,
@@ -913,7 +928,8 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
typedef typename TensorBlock::Dimensions TensorBlockDimensions;
//===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
- typedef internal::TensorBlockNotImplemented TensorBlockV2;
+ typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
+ typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
//===--------------------------------------------------------------------===//
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
@@ -987,6 +1003,13 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
block.block_strides(), TensorBlockDimensions(this->m_inputStrides),
const_cast<ScalarNoConst*>(block.data())));
}
+
+ template<typename TensorBlockV2>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writeBlockV2(
+ const TensorBlockDesc& desc, const TensorBlockV2& block) {
+ TensorBlockDesc arg_desc = desc.WithOffset(this->srcCoeff(desc.offset()));
+ this->m_impl.writeBlockV2(arg_desc, block);
+ }
};
namespace internal {