diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
3 files changed, 247 insertions, 102 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 6ca881f27..6a213096d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -105,7 +105,9 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value; static const int Layout = traits<LhsXprType>::Layout; typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val, - typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType; + typename traits<LhsXprType>::PointerType, + typename traits<RhsXprType>::PointerType>::type + PointerType; enum { Flags = 0 @@ -136,6 +138,80 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value; }; +// Helper class to allocate and deallocate temporary memory for packed buffers. +template <typename LhsScalar, typename RhsScalar> +struct TensorContractionBlockMemAllocator { + typedef void* BlockMemHandle; + + template <typename Device> + EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm, + const Index bk, + const Index bn, + LhsScalar** lhs_block, + RhsScalar** rhs_block) { + eigen_assert(lhs_block); + eigen_assert(rhs_block); + BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn); + char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size)); + eigen_assert(block_mem); + *lhs_block = reinterpret_cast<LhsScalar*>(block_mem); + *rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size); + return block_mem; + } + + template <typename Device> + EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices( + Device& d, const Index bm, const Index bk, const Index bn, + const Index num_lhs, const Index num_rhs, const Index num_slices, + std::vector<LhsScalar*>* lhs_blocks, + std::vector<RhsScalar*>* rhs_blocks) { + eigen_assert(num_slices > 0); + eigen_assert(num_lhs >= 0 && num_rhs >= 0) + eigen_assert(num_lhs == 0 || lhs_blocks); + eigen_assert(num_rhs == 0 || rhs_blocks); + BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn); + void* block_mem = d.allocate( + (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices); + eigen_assert(block_mem); + char* mem = static_cast<char*>(block_mem); + + for (Index x = 0; x < num_slices; x++) { + if (num_lhs > 0) lhs_blocks[x].resize(num_lhs); + for (Index m = 0; m < num_lhs; m++) { + lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem); + mem += sz.lhs_size; + } + if (num_rhs > 0) rhs_blocks[x].resize(num_rhs); + for (Index n = 0; n < num_rhs; n++) { + rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem); + mem += sz.rhs_size; + } + } + + return block_mem; + } + + template <typename Device> + EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) { + d.deallocate(handle); + } + + private: + struct BlockSizes { + Index lhs_size; + Index rhs_size; + }; + EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm, + const Index bk, + const Index bn) { + Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1); + BlockSizes sz; + sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align; + sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align; + return sz; + } +}; + // WARNING: In this code we assume that Lhs and Rhs tensor expressions are in // ColMajor storage order. This property is guaranteed by the // TensorContractionOp evaluator. TensorContractionKernel specifies how we pack @@ -164,16 +240,28 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, // TensorContractionInputMapper, or some specialization of it based on the // type of tensor expression (e.g. TensorImagePatchOp has optimized input // mapper). -template<typename ResScalar, typename LhsScalar, typename RhsScalar, +template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex, typename OutputMapper, typename LhsMapper, typename RhsMapper> struct TensorContractionKernel { + TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n, + StorageIndex bm, StorageIndex bk, StorageIndex bn) + : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {} + + // Pack blocks of Lhs and Rhs into contiguous blocks in memory. + typedef LhsScalar* LhsBlock; + typedef RhsScalar* RhsBlock; + + // Packed Lhs/Rhs block memory allocator. + typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar> + BlockMemAllocator; + typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle; + typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; - typedef internal::gemm_pack_lhs<LhsScalar, StorageIndex, - typename LhsMapper::SubMapper, - Traits::mr, Traits::LhsProgress, - typename Traits::LhsPacket4Packing, ColMajor> + typedef internal::gemm_pack_lhs< + LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr, + Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> LhsPacker; typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex, @@ -186,29 +274,61 @@ struct TensorContractionKernel { /*ConjugateLhs*/ false, /*ConjugateRhs*/ false> GebpKernel; - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE - static void packLhs(LhsScalar* lhsBlock, - const typename LhsMapper::SubMapper& data_mapper, - const StorageIndex depth, const StorageIndex rows) { - LhsPacker()(lhsBlock, data_mapper, depth, rows, /*stride*/ 0, /*offset*/ 0); + template <typename Device> + EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, + RhsBlock* rhs_block) { + return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block); + } + + template <typename Device> + EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices( + Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs, + const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks, + std::vector<RhsBlock>* rhs_blocks) { + return BlockMemAllocator::allocateSlices( + d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks); + } + + template <typename Device> + EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) { + BlockMemAllocator::deallocate(d, handle); } - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE - static void packRhs(RhsScalar* rhsBlock, - const typename RhsMapper::SubMapper& data_mapper, - const StorageIndex depth, const StorageIndex cols) { - RhsPacker()(rhsBlock, data_mapper, depth, cols); + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs( + LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper, + const StorageIndex depth, const StorageIndex rows) { + LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0, + /*offset*/ 0); } - EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE - static void invoke(const OutputMapper& output_mapper, - const LhsScalar* lhsBlock, const RhsScalar* rhsBlock, - const StorageIndex rows, const StorageIndex depth, - const StorageIndex cols, const ResScalar alpha) { + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs( + RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper, + const StorageIndex depth, const StorageIndex cols) { + RhsPacker()(*rhsBlock, data_mapper, depth, cols); + } + + EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke( + const OutputMapper& output_mapper, const LhsBlock& lhsBlock, + const RhsBlock& rhsBlock, const StorageIndex rows, + const StorageIndex depth, const StorageIndex cols, + const ResScalar alpha) { + static const int kComputeStrideFromBlockDimensions = -1; GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha, - /*strideA*/ -1, /*strideB*/ -1, + /*strideA*/ kComputeStrideFromBlockDimensions, + /*strideB*/ kComputeStrideFromBlockDimensions, /*offsetA*/ 0, /*offsetB*/ 0); } + + private: + // These are dimensions of the original Tensors, and selected block sizes. The + // actual block sizes passed to all function above might be smaller because of + // the partial blocks at the end. + const StorageIndex m; + const StorageIndex k; + const StorageIndex n; + const StorageIndex bm; + const StorageIndex bk; + const StorageIndex bn; }; } // end namespace internal @@ -257,7 +377,7 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp public: typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType, - typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType; + typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType; typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested; typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; @@ -340,10 +460,10 @@ struct TensorContractionEvaluatorBase EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType& op, const Device& device) - : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), + : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), op.lhsExpression(), op.rhsExpression()), device), - m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), - op.rhsExpression(), op.lhsExpression()), device), + m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), + op.rhsExpression(), op.lhsExpression()), device), m_device(device), m_output_kernel(op.outputKernel()), m_result(NULL) { @@ -737,11 +857,18 @@ struct TensorContractionEvaluatorBase const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); - const Index sizeA = mc * kc; - const Index sizeB = kc * nc; - LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))); - RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))); + typedef typename TensorContractionKernel::LhsBlock LhsBlock; + typedef typename TensorContractionKernel::RhsBlock RhsBlock; + + LhsBlock blockA; + RhsBlock blockB; + + TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc); + + typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle; + const BlockMemHandle packed_mem = + kernel.allocate(this->m_device, &blockA, &blockB); for(Index i2=0; i2<m; i2+=mc) { @@ -749,22 +876,20 @@ struct TensorContractionEvaluatorBase for (Index k2 = k_start; k2 < k_end; k2 += kc) { // make sure we don't overshoot right edge of left matrix, then pack vertical panel const Index actual_kc = numext::mini(k2 + kc, k_end) - k2; - TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2), - actual_kc, actual_mc); + kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc); // series of horizontal blocks for (Index j2 = 0; j2 < n; j2 += nc) { // make sure we don't overshoot right edge of right matrix, then pack block const Index actual_nc = numext::mini(j2 + nc, n) - j2; - TensorContractionKernel::packRhs(blockB, rhs.getSubMapper(k2, j2), - actual_kc, actual_nc); + kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc, + actual_nc); // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation const OutputMapper output_mapper = output.getSubMapper(i2, j2); - TensorContractionKernel::invoke(output_mapper, blockA, blockB, - actual_mc, actual_kc, actual_nc, - Scalar(1)); + kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc, + actual_nc, Scalar(1)); // We are done with this [i2, j2] output block. if (use_output_kernel && k2 + kc >= k_end) { @@ -775,8 +900,7 @@ struct TensorContractionEvaluatorBase } } - this->m_device.deallocate(blockA); - this->m_device.deallocate(blockB); + kernel.deallocate(this->m_device, packed_mem); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 142492603..1be823fd1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -24,12 +24,17 @@ enum { */ /// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which /// is scalar * for CoeffLoader. -template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> struct CoeffLoader; -template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t, - int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, - template <class> class MakePointer_ = MakePointer> class BaseTensorContractionMapper; +template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> +struct CoeffLoader; -template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> struct CoeffLoader { +template <typename Scalar, typename Index, int side, typename Tensor, + typename nocontract_t, typename contract_t, int packet_size, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + template <class> class MakePointer_ = MakePointer> +class BaseTensorContractionMapper; + +template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> +struct CoeffLoader { enum { DirectOffsets = false }; @@ -40,6 +45,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer eigen_assert(false && "unsupported"); } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type + data() const { + eigen_assert(false && "unsupported"); + return NULL; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); } template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -48,12 +59,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer return m_tensor.template packet<LoadMode>(index); } - private: const Tensor m_tensor; }; -template <typename Tensor, template <class> class MakePointer_> struct CoeffLoader<Tensor, true, MakePointer_> { +template <typename Tensor, template <class> class MakePointer_> +struct CoeffLoader<Tensor, true, MakePointer_> { enum { DirectOffsets = true }; @@ -64,6 +75,11 @@ template <typename Tensor, template <class> class MakePointer_> struct CoeffLoad m_data += offset; } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type + data() const { + return m_data; + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); } template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE @@ -214,6 +230,17 @@ class SimpleTensorContractionMapper { return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1; } + const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const { + return m_tensor; + } + + const nocontract_t& nocontract_strides() const { + return m_nocontract_strides; + } + const nocontract_t& ij_strides() const { return m_ij_strides; } + const contract_t& contract_strides() const { return m_contract_strides; } + const contract_t& k_strides() const { return m_k_strides; } + protected: CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor; const nocontract_t m_nocontract_strides; @@ -445,6 +472,10 @@ class TensorContractionSubMapper { return false; } + const ParentMapper& base_mapper() const { return m_base_mapper; } + Index vert_offset() const { return m_vert_offset; } + Index horiz_offset() const { return m_horiz_offset; } + private: ParentMapper m_base_mapper; const Index m_vert_offset; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index adf57c892..caa8d1767 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -280,6 +280,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper> TensorContractionKernel; + typedef typename TensorContractionKernel::LhsBlock LhsBlock; + typedef typename TensorContractionKernel::RhsBlock RhsBlock; + typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle; + Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, @@ -311,7 +315,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT gm_(gm), gn_(gn), nm0_(nm0), - nn0_(nn0) + nn0_(nn0), + kernel_(m_, k_, n_, bm_, bk_, bn_) { // These two options are mutually exclusive. eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only)); @@ -342,26 +347,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } // Allocate memory for packed rhs/lhs matrices. - size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1); - size_t lhs_size = - divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align; - size_t rhs_size = - divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align; - packed_mem_ = static_cast<char*>(device_.allocate( - (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1))); - char* mem = static_cast<char*>(packed_mem_); - for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) { - packed_lhs_[x].resize(nm0_); - for (Index m = 0; m < nm0_; m++) { - packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem); - mem += lhs_size; - } - packed_rhs_[x].resize(nn0_); - for (Index n = 0; n < nn0_; n++) { - packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem); - mem += rhs_size; - } - } + packed_mem_ = kernel_.allocateSlices( // + device_, // + /*num_lhs=*/nm0_, // + /*num_rhs=*/nn0_, // + /*num_slices=*/std::min<Index>(nk_, P - 1), // + packed_lhs_, packed_rhs_); if (parallelize_by_sharding_dim_only_) { const int num_worker_threads = device_.numThreadsInPool(); @@ -373,14 +364,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT std::memory_order_relaxed); Index num_blocks = num_worker_threads * gn_; - thread_local_packed_mem_ = device_.allocate(num_blocks * rhs_size); - mem = static_cast<char*>(thread_local_packed_mem_); + thread_local_packed_mem_ = kernel_.allocateSlices( // + device_, // + /*num_lhs=*/0, // + /*num_rhs=*/num_blocks, // + /*num_slices=*/1, // + /*lhs_blocks=*/nullptr, &thread_local_packed_rhs_); - thread_local_packed_rhs_.resize(num_blocks, nullptr); - for (Index i = 0; i < num_blocks; ++i) { - thread_local_packed_rhs_[i] = reinterpret_cast<RhsScalar*>(mem); - mem += rhs_size; - } } else { can_use_thread_local_packed_ = new std::atomic<bool>[nm_]; for (int i = 0; i < nm_; ++i) @@ -388,14 +378,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT std::memory_order_relaxed); Index num_blocks = num_worker_threads * gm_; - thread_local_packed_mem_ = device_.allocate(num_blocks * lhs_size); - mem = static_cast<char*>(thread_local_packed_mem_); - - thread_local_packed_lhs_.resize(num_blocks, nullptr); - for (Index i = 0; i < num_blocks; ++i) { - thread_local_packed_lhs_[i] = reinterpret_cast<LhsScalar*>(mem); - mem += lhs_size; - } + thread_local_packed_mem_ = kernel_.allocateSlices( // + device_, // + /*num_lhs=*/num_blocks, // + /*num_rhs=*/0, // + /*num_slices=*/1, &thread_local_packed_lhs_, // + /*rhs_blocks=*/nullptr); } } } @@ -405,9 +393,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m]; delete[] state_kernel_[x]; } - device_.deallocate(packed_mem_); + kernel_.deallocate(device_, packed_mem_); if (parallelize_by_sharding_dim_only_) { - device_.deallocate(thread_local_packed_mem_); + kernel_.deallocate(device_, thread_local_packed_mem_); delete[] can_use_thread_local_packed_; } } @@ -455,6 +443,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // coarsening). const Index nm0_; const Index nn0_; + // Tensor contraction kernel. + TensorContractionKernel kernel_; // Parallelization strategy. // @@ -491,9 +481,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // actively executing + one to track completion of kernels in the second // slice. static const Index P = 3; - void* packed_mem_; - std::vector<LhsScalar*> packed_lhs_[P - 1]; - std::vector<RhsScalar*> packed_rhs_[P - 1]; + + // Handle to the allocated temporary storage for Lhs/Rhs blocks. + BlockMemHandle packed_mem_; + std::vector<LhsBlock> packed_lhs_[P - 1]; + std::vector<RhsBlock> packed_rhs_[P - 1]; // If we choose to parallelize only by the sharding dimension, each thread // will have it's own "thead local" (not a c++ thread local storage) memory @@ -511,11 +503,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // completion of the K-1 kernel, so we have to allocate "global" packed_lhs_ // and packed_rhs_ to allow kernels to be executed later on a thread // different from the thread that was used for packing. - void* thread_local_packed_mem_; + BlockMemHandle thread_local_packed_mem_; - // Only one of these will beinitialized depending on shard_by_col value. - std::vector<LhsScalar*> thread_local_packed_lhs_; - std::vector<RhsScalar*> thread_local_packed_rhs_; + // Only one of these will be initialized depending on shard_by_col value. + std::vector<LhsBlock> thread_local_packed_lhs_; + std::vector<RhsBlock> thread_local_packed_rhs_; // After a particular shard for Kth slice missed thread local execution // opportunity (K-1 slice didn't complete kernels execution), we can no @@ -532,7 +524,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT std::atomic<Index> state_packing_ready_[P]; std::atomic<Index> state_switch_[P]; - LhsScalar* packed_lhs(Index m, Index k, Index m1, bool use_thread_local) { + LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) { if (use_thread_local) { eigen_assert(!shard_by_col_); @@ -546,7 +538,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } } - RhsScalar* packed_rhs(Index n, Index k, Index n1, bool use_thread_local) { + RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) { if (use_thread_local) { eigen_assert(shard_by_col_); @@ -580,7 +572,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } else { // If we can't guarantee that all kernels in `k` slice will be // executed sequentially in current thread, it's no longer safe to use - // thread local memory in followig slices along the k dimensions. + // thread local memory in following slices along the k dimensions. eigen_assert(k > 0); can_use_thread_local_packed_[m].store(false, std::memory_order_relaxed); @@ -589,9 +581,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const Index mend = m * gm_ + gm(m); for (Index m1 = m * gm_; m1 < mend; m1++) - TensorContractionKernel::packLhs(packed_lhs(m, k, m1, use_thread_local), - lhs_.getSubMapper(m1 * bm_, k * bk_), - bk(k), bm(m1)); + kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local), + lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1)); if (!parallel_pack_ && shard_by_col_) { assert(!use_thread_local); @@ -634,9 +625,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // deadlocks. memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar)); } - TensorContractionKernel::packRhs(packed_rhs(n, k, n1, use_thread_local), - rhs_.getSubMapper(k * bk_, n1 * bn_), - bk(k), bn(n1)); + kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local), + rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1)); } if (parallel_pack_ || shard_by_col_) { @@ -661,7 +651,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT for (Index n1 = n * gn_; n1 < nend; n1++) { for (Index m1 = m * gm_; m1 < mend; m1++) { const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); - TensorContractionKernel::invoke( + kernel_.invoke( output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local), packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), @@ -678,7 +668,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT for (Index m1 = m * gm_; m1 < mend; m1++) for (Index n1 = n * gn_; n1 < nend; n1++) { const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); - TensorContractionKernel::invoke( + kernel_.invoke( output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local), packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), |