author    Eugene Zhulenev <ezhulenev@google.com>  2018-10-15 16:52:33 -0700
committer Eugene Zhulenev <ezhulenev@google.com>  2018-10-15 16:52:33 -0700
commit    900c7c61bb6abca5b3324c11ba1b45fa3e31c5fa (patch)
tree      962268a4470499f765a4d7a716495cd6cb6ce80c /unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
parent    d835a0bf539e2827502f3d7ddcb1033baf05ecd4 (diff)
Check if it's allowed to squeeze inner dimensions in TensorBlockIO
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h  36
1 file changed, 32 insertions, 4 deletions
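
For context, a minimal standalone sketch of the idea behind this change (hypothetical names, not the actual Eigen helpers): inner dimensions whose tensor-to-block mapping is the identity are contiguous in both the tensor and the block, so they can be collapsed ("squeezed") into one larger inner copy; the first reordered dimension stops the collapse, since copying across it is no longer contiguous.

#include <array>
#include <iostream>

// Hypothetical illustration of squeezing inner dimensions for block IO.
// A dimension is squeezable only while the tensor-to-block mapping keeps it
// in place (map[i] == i); counting stops at the first reordered dimension.
template <int NumDims>
int NumSqueezableInnerDims(const std::array<int, NumDims>& map, bool col_major) {
  int n = 0;
  if (col_major) {
    for (int i = 0; i < NumDims; ++i) {       // inner dims come first in ColMajor
      if (map[i] != i) break;
      ++n;
    }
  } else {
    for (int i = NumDims - 1; i >= 0; --i) {  // inner dims come last in RowMajor
      if (map[i] != i) break;
      ++n;
    }
  }
  return n;
}

int main() {
  std::array<int, 3> identity_map  = {0, 1, 2};  // no reordering
  std::array<int, 3> reordered_map = {1, 0, 2};  // innermost two dims swapped
  std::cout << NumSqueezableInnerDims<3>(identity_map, /*col_major=*/true)        // 3
            << " " << NumSqueezableInnerDims<3>(reordered_map, /*col_major=*/true)  // 0
            << std::endl;
  return 0;
}

With a reordered map the count drops to zero, which is why the patch below bounds both loops in Copy() by num_squeezable_dims instead of NumDims.
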
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index a59a5d5b2..91c77b05a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -246,6 +246,8 @@ class TensorBlockIO {
typedef TensorBlockCopyOp<Scalar, StorageIndex> BlockCopyOp;
protected:
+ typedef array<StorageIndex, NumDims> Dimensions;
+
struct BlockIteratorState {
StorageIndex input_stride;
StorageIndex output_stride;
@@ -262,22 +264,46 @@ class TensorBlockIO {
count(0) {}
};
+ // Compute how many inner dimensions it's allowed to squeeze when doing IO
+ // between a tensor and a block. It's safe to squeeze inner dimensions only
+ // if they are not reordered.
+ static int NumSqueezableInnerDims(const Dimensions& tensor_to_block_dim_map) {
+ int num_squeezable_dims = 0;
+ if (Layout == ColMajor) {
+ for (int i = 0; i < NumDims; ++i) {
+ if (tensor_to_block_dim_map[i] == i) num_squeezable_dims++;
+ else break;
+ }
+ } else {
+ for (int i = NumDims - 1; i >= 0; --i) {
+ if (tensor_to_block_dim_map[i] == i) num_squeezable_dims++;
+ else break;
+ }
+ }
+ return num_squeezable_dims;
+ }
+
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Copy(
const Block& block, StorageIndex first_coeff_index,
- const array<StorageIndex, NumDims>& tensor_to_block_dim_map,
- const array<StorageIndex, NumDims>& tensor_strides, const Scalar* src_data,
+ const Dimensions& tensor_to_block_dim_map,
+ const Dimensions& tensor_strides,
+ const Scalar* src_data,
Scalar* dst_data) {
+ // Do not squeeze reordered inner dimensions.
+ int num_squeezable_dims = NumSqueezableInnerDims(tensor_to_block_dim_map);
+
// Find the innermost tensor dimension whose size is not 1. This is the
// effective inner dim. If all dimensions are of size 1, then fallback to
// using the actual innermost dim to avoid out-of-bound access.
StorageIndex num_size_one_inner_dims = 0;
- for (int i = 0; i < NumDims; ++i) {
+ for (int i = 0; i < num_squeezable_dims; ++i) {
const int dim = cond<Layout>()(i, NumDims - i - 1);
if (block.block_sizes()[tensor_to_block_dim_map[dim]] != 1) {
num_size_one_inner_dims = i;
break;
}
}
+
// Calculate strides and dimensions.
const StorageIndex tensor_stride1_dim = cond<Layout>()(
num_size_one_inner_dims, NumDims - num_size_one_inner_dims - 1);
@@ -286,7 +312,9 @@ class TensorBlockIO {
StorageIndex block_inner_dim_size =
NumDims == 0 ? 1
: block.block_sizes()[block_dim_for_tensor_stride1_dim];
- for (Index i = num_size_one_inner_dims + 1; i < NumDims; ++i) {
+
+ // Squeeze multiple inner dims into one for larger inner dim size.
+ for (Index i = num_size_one_inner_dims + 1; i < num_squeezable_dims; ++i) {
const Index dim = cond<Layout>()(i, NumDims - i - 1);
const StorageIndex block_stride =
block.block_strides()[tensor_to_block_dim_map[dim]];