diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-12-10 11:58:30 -0800 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-12-10 14:31:44 -0800 |
commit | dbca11e8805ec07660d8f966a1884ad0be302f15 (patch) | |
tree | 9da1438132a9a40de7ca3abafec2e559eb0449e3 /unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | |
parent | c49f0d851ab77c9e4d782b453b4b0428bce903d3 (diff) |
Remove TensorBlock.h and old TensorBlock/BlockMapper
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h | 41 |
1 file changed, 15 insertions, 26 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index db123d8a4..7b7b670ed 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -172,9 +172,8 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable, EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE void run(const Expression& expr, const DefaultDevice& device = DefaultDevice()) { - typedef TensorBlock<ScalarNoConst, StorageIndex, NumDims, Evaluator::Layout> TensorBlock; - typedef TensorBlockMapper<ScalarNoConst, StorageIndex, NumDims, Evaluator::Layout> TensorBlockMapper; - typedef typename TensorBlock::Dimensions TensorBlockDimensions; + typedef TensorBlockV2Mapper<NumDims, Evaluator::Layout, StorageIndex> + TensorBlockMapper; typedef internal::TensorBlockDescriptor<NumDims, StorageIndex> TensorBlockDesc; @@ -192,17 +191,15 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable, evaluator.getResourceRequirements(); const TensorBlockMapper block_mapper( - TensorBlockDimensions(evaluator.dimensions()), requirements.shapeV1(), - requirements.size); + typename TensorBlockDesc::Dimensions(evaluator.dimensions()), + requirements); // Share scratch memory allocator between all blocks. 
TensorBlockScratch scratch(device); - const StorageIndex total_block_count = block_mapper.total_block_count(); + const StorageIndex total_block_count = block_mapper.blockCount(); for (StorageIndex i = 0; i < total_block_count; ++i) { - TensorBlock block = block_mapper.GetBlockForIndex(i, NULL); - - TensorBlockDesc desc(block.first_coeff_index(), block.block_sizes()); + TensorBlockDesc desc = block_mapper.blockDescriptor(i); evaluator.evalBlockV2(desc, scratch); scratch.reset(); } @@ -226,8 +223,6 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable, template <typename TensorBlockMapper> struct TensorExecutorTilingContext { - typedef typename TensorBlockMapper::Block TensorBlock; - TensorExecutorTilingContext() : buffer(nullptr) {} TensorExecutorTilingContext(const TensorBlockMapper& b_mapper, const TensorOpCost& b_cost, void* b_buffer, @@ -274,9 +269,9 @@ TensorExecutorTilingContext<TensorBlockMapper> GetTensorExecutorTilingContext( TensorBlockMapper block_mapper( typename TensorBlockMapper::Dimensions(evaluator.dimensions()), - requirements.shapeV1(), block_size); + requirements); - block_size = block_mapper.block_dims_total_size(); + block_size = block_mapper.blockTotalSize(); const size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1); const size_t aligned_blocksize = align * @@ -382,9 +377,7 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, static const int NumDims = traits<Expression>::NumDimensions; typedef TensorEvaluator<Expression, ThreadPoolDevice> Evaluator; - typedef TensorBlockMapper<ScalarNoConst, IndexType, NumDims, - Evaluator::Layout> - BlockMapper; + typedef TensorBlockV2Mapper<NumDims, Evaluator::Layout, IndexType> BlockMapper; typedef TensorExecutorTilingContext<BlockMapper> TilingContext; typedef internal::TensorBlockDescriptor<NumDims, IndexType> @@ -408,14 +401,13 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, TensorBlockScratch scratch(device); for (IndexType block_idx = firstBlockIdx; 
block_idx < lastBlockIdx; ++block_idx) { - auto block = tiling.block_mapper.GetBlockForIndex(block_idx, nullptr); - TensorBlockDesc desc(block.first_coeff_index(), block.block_sizes()); + TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx); evaluator.evalBlockV2(desc, scratch); scratch.reset(); } }; - device.parallelFor(tiling.block_mapper.total_block_count(), tiling.cost, + device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost, eval_block); } evaluator.cleanup(); @@ -486,9 +478,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, static const int NumDims = traits<Expression>::NumDimensions; typedef TensorEvaluator<Expression, ThreadPoolDevice> Evaluator; - typedef TensorBlockMapper<ScalarNoConst, IndexType, NumDims, - Evaluator::Layout> - BlockMapper; + typedef TensorBlockV2Mapper<NumDims, Evaluator::Layout, IndexType> BlockMapper; typedef TensorExecutorTilingContext<BlockMapper> TilingContext; typedef internal::TensorBlockDescriptor<NumDims, IndexType> TensorBlockDesc; @@ -518,14 +508,13 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback, for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) { - auto block = - ctx->tiling.block_mapper.GetBlockForIndex(block_idx, nullptr); - TensorBlockDesc desc(block.first_coeff_index(), block.block_sizes()); + TensorBlockDesc desc = + ctx->tiling.block_mapper.blockDescriptor(block_idx); ctx->evaluator.evalBlockV2(desc, scratch); scratch.reset(); } }; - ctx->device.parallelForAsync(ctx->tiling.block_mapper.total_block_count(), + ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(), ctx->tiling.cost, eval_block, [ctx]() { delete ctx; }); }; |