author    Eugene Zhulenev <ezhulenev@google.com> 2019-04-01 11:47:31 -0700
committer Eugene Zhulenev <ezhulenev@google.com> 2019-04-01 11:47:31 -0700
commit    4e2f6de1a8fd9a659dc40ed54fedff9abdef3b1f (patch)
tree      e510ad53ee053b68327462c0e6d944db5dc362d0 /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent    45e65fbb7791e453f88f959111cff45e0fb7dd6f (diff)
Add support for custom packed Lhs/Rhs blocks in tensor contractions
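
The evaluator now drives packing through a kernel instance instead of static
helpers: the kernel owns the selected block sizes, hands out LhsBlock/RhsBlock
handles carved out of a single device allocation, and releases everything
through one opaque handle. Below is a minimal standalone sketch of that flow;
the Device and ContractionKernelSketch types are simplified stand-ins invented
for illustration, not the actual Eigen internal types, and the block sizes are
arbitrary.

    #include <cstddef>
    #include <cstdlib>

    // Stand-ins for the roles in this patch: a device that hands out raw
    // memory, and a contraction kernel that owns its block sizes and carves
    // typed Lhs/Rhs blocks out of one allocation.
    struct Device {
      void* allocate(std::size_t bytes) { return std::malloc(bytes); }
      void deallocate(void* ptr) { std::free(ptr); }
    };

    template <typename LhsScalar, typename RhsScalar>
    struct ContractionKernelSketch {
      typedef void* BlockMemHandle;
      typedef LhsScalar* LhsBlock;
      typedef RhsScalar* RhsBlock;

      ContractionKernelSketch(int bm, int bk, int bn)
          : bm(bm), bk(bk), bn(bn) {}

      // One allocation backs both packed blocks; the handle is what is freed.
      BlockMemHandle allocate(Device& d, LhsBlock* lhs, RhsBlock* rhs) {
        const std::size_t lhs_bytes = sizeof(LhsScalar) * bm * bk;
        const std::size_t rhs_bytes = sizeof(RhsScalar) * bk * bn;
        char* mem = static_cast<char*>(d.allocate(lhs_bytes + rhs_bytes));
        *lhs = reinterpret_cast<LhsScalar*>(mem);
        *rhs = reinterpret_cast<RhsScalar*>(mem + lhs_bytes);
        return mem;
      }
      static void deallocate(Device& d, BlockMemHandle handle) {
        d.deallocate(handle);
      }

      const int bm, bk, bn;
    };

    int main() {
      Device d;
      ContractionKernelSketch<float, float> kernel(/*bm=*/32, /*bk=*/16,
                                                   /*bn=*/8);
      float* blockA;
      float* blockB;
      void* handle = kernel.allocate(d, &blockA, &blockB);
      // ... packLhs/packRhs into blockA/blockB, invoke the GEBP kernel ...
      ContractionKernelSketch<float, float>::deallocate(d, handle);
    }

In the real patch the kernel additionally exposes allocateSlices for callers
that need several packed blocks per slice, and packLhs/packRhs/invoke take the
LhsBlock/RhsBlock handles rather than raw scalar pointers, which is what lets
specializations substitute custom packed block formats.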
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 202
1 file changed, 163 insertions(+), 39 deletions(-)
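
The TensorContractionBlockMemAllocator introduced below rounds each packed
block size up to EIGEN_MAX_ALIGN_BYTES, so blocks placed back to back in one
allocation all stay aligned. A small sketch of the same size arithmetic used
by ComputeLhsRhsBlockSizes and allocateSlices; the 64-byte alignment and all
dimensions are illustrative assumptions.

    #include <cstddef>
    #include <iostream>

    // divup(a, b): number of b-sized chunks needed to cover a.
    static std::ptrdiff_t divup(std::ptrdiff_t a, std::ptrdiff_t b) {
      return (a + b - 1) / b;
    }

    int main() {
      const std::ptrdiff_t align = 64;  // stand-in for EIGEN_MAX_ALIGN_BYTES
      const std::ptrdiff_t bm = 48, bk = 32, bn = 24;

      // Round each block size up to the alignment boundary, so the Rhs block
      // placed right after the Lhs block in the same allocation is aligned.
      const std::ptrdiff_t lhs_size =
          divup(bm * bk * std::ptrdiff_t(sizeof(float)), align) * align;
      const std::ptrdiff_t rhs_size =
          divup(bn * bk * std::ptrdiff_t(sizeof(float)), align) * align;

      // allocateSlices then lays out num_slices copies of
      //   [num_lhs Lhs blocks][num_rhs Rhs blocks]
      const std::ptrdiff_t num_lhs = 2, num_rhs = 3, num_slices = 4;
      const std::ptrdiff_t total =
          (num_lhs * lhs_size + num_rhs * rhs_size) * num_slices;
      std::cout << lhs_size << " " << rhs_size << " " << total << "\n";
      // prints: 6144 3072 86016
    }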
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 6ca881f27..6a213096d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -105,7 +105,9 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern
static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
static const int Layout = traits<LhsXprType>::Layout;
typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
- typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
+ typename traits<LhsXprType>::PointerType,
+ typename traits<RhsXprType>::PointerType>::type
+ PointerType;
enum {
Flags = 0
@@ -136,6 +138,80 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};
+// Helper class to allocate and deallocate temporary memory for packed buffers.
+template <typename LhsScalar, typename RhsScalar>
+struct TensorContractionBlockMemAllocator {
+ typedef void* BlockMemHandle;
+
+ template <typename Device>
+ EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
+ const Index bk,
+ const Index bn,
+ LhsScalar** lhs_block,
+ RhsScalar** rhs_block) {
+ eigen_assert(lhs_block);
+ eigen_assert(rhs_block);
+ BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
+ char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
+ eigen_assert(block_mem);
+ *lhs_block = reinterpret_cast<LhsScalar*>(block_mem);
+ *rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size);
+ return block_mem;
+ }
+
+ template <typename Device>
+ EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
+ Device& d, const Index bm, const Index bk, const Index bn,
+ const Index num_lhs, const Index num_rhs, const Index num_slices,
+ std::vector<LhsScalar*>* lhs_blocks,
+ std::vector<RhsScalar*>* rhs_blocks) {
+ eigen_assert(num_slices > 0);
+ eigen_assert(num_lhs >= 0 && num_rhs >= 0);
+ eigen_assert(num_lhs == 0 || lhs_blocks);
+ eigen_assert(num_rhs == 0 || rhs_blocks);
+ BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
+ void* block_mem = d.allocate(
+ (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
+ eigen_assert(block_mem);
+ char* mem = static_cast<char*>(block_mem);
+
+ for (Index x = 0; x < num_slices; x++) {
+ if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
+ for (Index m = 0; m < num_lhs; m++) {
+ lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem);
+ mem += sz.lhs_size;
+ }
+ if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
+ for (Index n = 0; n < num_rhs; n++) {
+ rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem);
+ mem += sz.rhs_size;
+ }
+ }
+
+ return block_mem;
+ }
+
+ template <typename Device>
+ EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
+ d.deallocate(handle);
+ }
+
+ private:
+ struct BlockSizes {
+ Index lhs_size;
+ Index rhs_size;
+ };
+ EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm,
+ const Index bk,
+ const Index bn) {
+ Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
+ BlockSizes sz;
+ sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
+ sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
+ return sz;
+ }
+};
+
// WARNING: In this code we assume that Lhs and Rhs tensor expressions are in
// ColMajor storage order. This property is guaranteed by the
// TensorContractionOp evaluator. TensorContractionKernel specifies how we pack
@@ -164,16 +240,28 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
// TensorContractionInputMapper, or some specialization of it based on the
// type of tensor expression (e.g. TensorImagePatchOp has optimized input
// mapper).
-template<typename ResScalar, typename LhsScalar, typename RhsScalar,
+template <typename ResScalar, typename LhsScalar, typename RhsScalar,
typename StorageIndex, typename OutputMapper, typename LhsMapper,
typename RhsMapper>
struct TensorContractionKernel {
+ TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,
+ StorageIndex bm, StorageIndex bk, StorageIndex bn)
+ : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}
+
+ // Pack blocks of Lhs and Rhs into contiguous blocks in memory.
+ typedef LhsScalar* LhsBlock;
+ typedef RhsScalar* RhsBlock;
+
+ // Packed Lhs/Rhs block memory allocator.
+ typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>
+ BlockMemAllocator;
+ typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;
+
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
- typedef internal::gemm_pack_lhs<LhsScalar, StorageIndex,
- typename LhsMapper::SubMapper,
- Traits::mr, Traits::LhsProgress,
- typename Traits::LhsPacket4Packing, ColMajor>
+ typedef internal::gemm_pack_lhs<
+ LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
+ Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
LhsPacker;
typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex,
@@ -186,29 +274,61 @@ struct TensorContractionKernel {
/*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
GebpKernel;
- EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
- static void packLhs(LhsScalar* lhsBlock,
- const typename LhsMapper::SubMapper& data_mapper,
- const StorageIndex depth, const StorageIndex rows) {
- LhsPacker()(lhsBlock, data_mapper, depth, rows, /*stride*/ 0, /*offset*/ 0);
+ template <typename Device>
+ EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
+ RhsBlock* rhs_block) {
+ return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
+ }
+
+ template <typename Device>
+ EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
+ Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
+ const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
+ std::vector<RhsBlock>* rhs_blocks) {
+ return BlockMemAllocator::allocateSlices(
+ d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
+ }
+
+ template <typename Device>
+ EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
+ BlockMemAllocator::deallocate(d, handle);
}
- EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
- static void packRhs(RhsScalar* rhsBlock,
- const typename RhsMapper::SubMapper& data_mapper,
- const StorageIndex depth, const StorageIndex cols) {
- RhsPacker()(rhsBlock, data_mapper, depth, cols);
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
+ LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
+ const StorageIndex depth, const StorageIndex rows) {
+ LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0,
+ /*offset*/ 0);
}
- EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
- static void invoke(const OutputMapper& output_mapper,
- const LhsScalar* lhsBlock, const RhsScalar* rhsBlock,
- const StorageIndex rows, const StorageIndex depth,
- const StorageIndex cols, const ResScalar alpha) {
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
+ RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
+ const StorageIndex depth, const StorageIndex cols) {
+ RhsPacker()(*rhsBlock, data_mapper, depth, cols);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
+ const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
+ const RhsBlock& rhsBlock, const StorageIndex rows,
+ const StorageIndex depth, const StorageIndex cols,
+ const ResScalar alpha) {
+ static const int kComputeStrideFromBlockDimensions = -1;
GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
- /*strideA*/ -1, /*strideB*/ -1,
+ /*strideA*/ kComputeStrideFromBlockDimensions,
+ /*strideB*/ kComputeStrideFromBlockDimensions,
/*offsetA*/ 0, /*offsetB*/ 0);
}
+
+ private:
+ // These are the dimensions of the original Tensors and the selected block
+ // sizes. The actual block sizes passed to the functions above might be
+ // smaller because of the partial blocks at the end.
+ const StorageIndex m;
+ const StorageIndex k;
+ const StorageIndex n;
+ const StorageIndex bm;
+ const StorageIndex bk;
+ const StorageIndex bn;
};
} // end namespace internal
@@ -257,7 +377,7 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
public:
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
- typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
+ typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
@@ -340,10 +460,10 @@ struct TensorContractionEvaluatorBase
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorContractionEvaluatorBase(const XprType& op, const Device& device)
- : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
+ : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
op.lhsExpression(), op.rhsExpression()), device),
- m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
- op.rhsExpression(), op.lhsExpression()), device),
+ m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
+ op.rhsExpression(), op.lhsExpression()), device),
m_device(device),
m_output_kernel(op.outputKernel()),
m_result(NULL) {
@@ -737,11 +857,18 @@ struct TensorContractionEvaluatorBase
const Index kc = blocking.kc();
const Index mc = numext::mini(m, blocking.mc());
const Index nc = numext::mini(n, blocking.nc());
- const Index sizeA = mc * kc;
- const Index sizeB = kc * nc;
- LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
- RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
+ typedef typename TensorContractionKernel::LhsBlock LhsBlock;
+ typedef typename TensorContractionKernel::RhsBlock RhsBlock;
+
+ LhsBlock blockA;
+ RhsBlock blockB;
+
+ TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);
+
+ typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
+ const BlockMemHandle packed_mem =
+ kernel.allocate(this->m_device, &blockA, &blockB);
for(Index i2=0; i2<m; i2+=mc)
{
@@ -749,22 +876,20 @@ struct TensorContractionEvaluatorBase
for (Index k2 = k_start; k2 < k_end; k2 += kc) {
// make sure we don't overshoot right edge of left matrix, then pack vertical panel
const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
- TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2),
- actual_kc, actual_mc);
+ kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
// series of horizontal blocks
for (Index j2 = 0; j2 < n; j2 += nc) {
// make sure we don't overshoot right edge of right matrix, then pack block
const Index actual_nc = numext::mini(j2 + nc, n) - j2;
- TensorContractionKernel::packRhs(blockB, rhs.getSubMapper(k2, j2),
- actual_kc, actual_nc);
+ kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
+ actual_nc);
// call gebp (matrix kernel)
// The parameters here are copied from Eigen's GEMM implementation
const OutputMapper output_mapper = output.getSubMapper(i2, j2);
- TensorContractionKernel::invoke(output_mapper, blockA, blockB,
- actual_mc, actual_kc, actual_nc,
- Scalar(1));
+ kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
+ actual_nc, Scalar(1));
// We are done with this [i2, j2] output block.
if (use_output_kernel && k2 + kc >= k_end) {
@@ -775,8 +900,7 @@ struct TensorContractionEvaluatorBase
}
}
- this->m_device.deallocate(blockA);
- this->m_device.deallocate(blockB);
+ kernel.deallocate(this->m_device, packed_mem);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {