aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h6
1 files changed, 4 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
index b756be3b3..ba5ab1396 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
@@ -115,6 +115,7 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable,
const DefaultDevice& device = DefaultDevice()) {
typedef TensorBlock<ScalarNoConst, StorageIndex, NumDims, Evaluator::Layout> TensorBlock;
typedef TensorBlockMapper<ScalarNoConst, StorageIndex, NumDims, Evaluator::Layout> TensorBlockMapper;
+ typedef typename TensorBlock::Dimensions TensorBlockDimensions;
Evaluator evaluator(expr, device);
Index total_size = array_prod(evaluator.dimensions());
@@ -138,8 +139,9 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable,
evaluator.getResourceRequirements(&resources);
MergeResourceRequirements(resources, &block_shape, &block_total_size);
- TensorBlockMapper block_mapper(evaluator.dimensions(), block_shape,
- block_total_size);
+ TensorBlockMapper block_mapper(
+ TensorBlockDimensions(evaluator.dimensions()), block_shape,
+ block_total_size);
block_total_size = block_mapper.block_dims_total_size();
Scalar* data = static_cast<Scalar*>(