diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index b90791d8d..93bab11b1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -245,8 +245,8 @@ TensorExecutorTilingContext<TensorBlockMapper> GetTensorExecutorTilingContext( evaluator.getResourceRequirements(); // Update target block size based on cost model. - TensorOpCost cost = evaluator.costPerCoeff(Vectorizable); - double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(1, cost); + double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize( + 1, requirements.cost_per_coeff); requirements.size = static_cast<size_t>(1.0 / taskSize); TensorBlockMapper block_mapper( @@ -259,7 +259,8 @@ TensorExecutorTilingContext<TensorBlockMapper> GetTensorExecutorTilingContext( align * divup<size_t>(block_size * sizeof(typename Evaluator::Scalar), align); - return {block_mapper, cost * block_size, aligned_blocksize}; + return {block_mapper, requirements.cost_per_coeff * block_size, + aligned_blocksize}; } template <typename Evaluator, typename StorageIndex, bool Vectorizable> |