aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h54
1 files changed, 54 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
index 221f8e843..6cacf1cc1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
@@ -58,6 +58,60 @@ EIGEN_STRONG_INLINE DSizes<std::ptrdiff_t, sizeof...(Indices)> strides(
}
// -------------------------------------------------------------------------- //
+
+// Tensor block shape type defines what are the shape preference for the blocks
+// extracted from the larger tensor.
+//
+// Example: blocks of 100 elements from the large 100x100 tensor:
+// - tensor: 100x100
+// - target_block_size: 100
+//
+// TensorBlockShapeType:
+// - kUniformAllDims: 100 blocks of size 10x10
+// - kSkewedInnerDims: 100 blocks of size 100x1 (or 1x100 depending on a column
+// or row major layout)
+enum class TensorBlockV2ShapeType { kUniformAllDims, kSkewedInnerDims };
+
+struct TensorBlockV2ResourceRequirements {
+ TensorBlockV2ShapeType shape_type;
+ size_t size;
+
+ TensorBlockShapeType shapeV1() const {
+ return shape_type == TensorBlockV2ShapeType::kUniformAllDims
+ ? internal::kUniformAllDims
+ : internal::kSkewedInnerDims;
+ }
+
+ static TensorBlockV2ResourceRequirements
+ merge(const TensorBlockV2ResourceRequirements &lhs,
+ const TensorBlockV2ResourceRequirements &rhs) {
+ return {merge(lhs.shape_type, rhs.shape_type), merge(rhs.size, lhs.size)};
+ }
+
+ // This is a resource requirement that should be returned from expressions
+ // that do not have any block evaluation preference (e.g. default tensor
+ // expression with raw buffer access).
+ static TensorBlockV2ResourceRequirements any() {
+ return {TensorBlockV2ShapeType::kUniformAllDims, 1};
+ }
+
+private:
+ using Requirements = TensorBlockV2ResourceRequirements;
+
+ static size_t merge(size_t lhs_size, size_t rhs_size) {
+ return numext::maxi(lhs_size, rhs_size);
+ }
+
+ static TensorBlockV2ShapeType merge(TensorBlockV2ShapeType lhs,
+ TensorBlockV2ShapeType rhs) {
+ return (lhs == TensorBlockV2ShapeType::kSkewedInnerDims ||
+ rhs == TensorBlockV2ShapeType::kSkewedInnerDims)
+ ? TensorBlockV2ShapeType::kSkewedInnerDims
+ : TensorBlockV2ShapeType::kUniformAllDims;
+ }
+};
+
+// -------------------------------------------------------------------------- //
// TensorBlockDescriptor specifies a block offset within a tensor and the block
// sizes along each of the tensor dimensions.