diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h index 221f8e843..6cacf1cc1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h @@ -58,6 +58,60 @@ EIGEN_STRONG_INLINE DSizes<std::ptrdiff_t, sizeof...(Indices)> strides( } // -------------------------------------------------------------------------- // + +// Tensor block shape type defines what are the shape preference for the blocks +// extracted from the larger tensor. +// +// Example: blocks of 100 elements from the large 100x100 tensor: +// - tensor: 100x100 +// - target_block_size: 100 +// +// TensorBlockShapeType: +// - kUniformAllDims: 100 blocks of size 10x10 +// - kSkewedInnerDims: 100 blocks of size 100x1 (or 1x100 depending on a column +// or row major layout) +enum class TensorBlockV2ShapeType { kUniformAllDims, kSkewedInnerDims }; + +struct TensorBlockV2ResourceRequirements { + TensorBlockV2ShapeType shape_type; + size_t size; + + TensorBlockShapeType shapeV1() const { + return shape_type == TensorBlockV2ShapeType::kUniformAllDims + ? internal::kUniformAllDims + : internal::kSkewedInnerDims; + } + + static TensorBlockV2ResourceRequirements + merge(const TensorBlockV2ResourceRequirements &lhs, + const TensorBlockV2ResourceRequirements &rhs) { + return {merge(lhs.shape_type, rhs.shape_type), merge(rhs.size, lhs.size)}; + } + + // This is a resource requirement that should be returned from expressions + // that do not have any block evaluation preference (e.g. default tensor + // expression with raw buffer access). + static TensorBlockV2ResourceRequirements any() { + return {TensorBlockV2ShapeType::kUniformAllDims, 1}; + } + +private: + using Requirements = TensorBlockV2ResourceRequirements; + + static size_t merge(size_t lhs_size, size_t rhs_size) { + return numext::maxi(lhs_size, rhs_size); + } + + static TensorBlockV2ShapeType merge(TensorBlockV2ShapeType lhs, + TensorBlockV2ShapeType rhs) { + return (lhs == TensorBlockV2ShapeType::kSkewedInnerDims || + rhs == TensorBlockV2ShapeType::kSkewedInnerDims) + ? TensorBlockV2ShapeType::kSkewedInnerDims + : TensorBlockV2ShapeType::kUniformAllDims; + } +}; + +// -------------------------------------------------------------------------- // // TensorBlockDescriptor specifies a block offset within a tensor and the block // sizes along each of the tensor dimensions. |