-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h            | 32
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h  | 29
2 files changed, 39 insertions(+), 22 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index d61209133..87e8db3fd 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -180,6 +180,10 @@ template <typename ResScalar, typename LhsScalar, typename RhsScalar,
           typename StorageIndex, typename OutputMapper, typename LhsMapper,
           typename RhsMapper>
 struct TensorContractionKernel {
+  // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
+  // (otherwise beta should always be equal to 1).
+  enum { HasBeta = false };
+
   EIGEN_DEVICE_FUNC
   TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
                           StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
@@ -248,7 +252,9 @@ struct TensorContractionKernel {
       const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
       const RhsBlock& rhsBlock, const StorageIndex rows,
       const StorageIndex depth, const StorageIndex cols,
-      const ResScalar alpha) {
+      const ResScalar alpha, const ResScalar beta) {
+    // The default GEBP kernel does not support beta.
+    eigen_assert(beta == ResScalar(1));
     static const int kComputeStrideFromBlockDimensions = -1;
     GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                  /*strideA*/ kComputeStrideFromBlockDimensions,
@@ -772,15 +778,6 @@ struct TensorContractionEvaluatorBase
   void evalGemm(Scalar* buffer) const {
     // columns in left side, rows in right side
     const Index k = this->m_k_size;
-
-    // rows in left side
-    const Index m = this->m_i_size;
-
-    // columns in right side
-    const Index n = this->m_j_size;
-
-    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
-    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
     this->template evalGemmPartial<lhs_inner_dim_contiguous,
                                    rhs_inner_dim_contiguous,
                                    rhs_inner_dim_reordered,
@@ -866,6 +863,12 @@ struct TensorContractionEvaluatorBase
     const BlockMemHandle packed_mem =
         kernel.allocate(this->m_device, &blockA, &blockB);

+    // If the contraction kernel does not support beta, explicitly initialize
+    // the output buffer with zeroes.
+    if (!TensorContractionKernel::HasBeta) {
+      this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+    }
+
     for(Index i2=0; i2<m; i2+=mc)
     {
       const Index actual_mc = numext::mini(i2+mc,m)-i2;
@@ -874,6 +877,13 @@ struct TensorContractionEvaluatorBase
         const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
         kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

+        // If the kernel supports beta, there is no need to initialize the
+        // output buffer with zeroes.
+        const Scalar alpha = Scalar(1);
+        const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
+                                ? Scalar(0)
+                                : Scalar(1);
+
         // series of horizontal blocks
         for (Index j2 = 0; j2 < n; j2 += nc) {
           // make sure we don't overshoot right edge of right matrix, then pack block
@@ -885,7 +895,7 @@ struct TensorContractionEvaluatorBase
           // The parameters here are copied from Eigen's GEMM implementation
           const OutputMapper output_mapper = output.getSubMapper(i2, j2);
           kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
-                        actual_nc, Scalar(1));
+                        actual_nc, alpha, beta);

           // We are done with this [i2, j2] output block.
           if (use_output_kernel && k2 + kc >= k_end) {
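The change above threads a `beta` argument through `TensorContractionKernel::invoke()` so the kernel computes `C <- alpha * A * B + beta * C`. The sketch below (a hypothetical scalar reference kernel named `gemm_ref`, not Eigen's GEBP) illustrates the BLAS-style convention the new comments rely on: `beta == 0` must overwrite `C` rather than scale it, so the call is safe on an uninitialized buffer, while `beta == 1` accumulates into it.

#include <cstddef>

// Minimal sketch of C <- alpha * A * B + beta * C for column-major matrices
// with explicit leading dimensions. Hypothetical helper, not Eigen's API.
void gemm_ref(std::size_t m, std::size_t n, std::size_t k, float alpha,
              const float* A, std::size_t lda, const float* B, std::size_t ldb,
              float beta, float* C, std::size_t ldc) {
  for (std::size_t j = 0; j < n; ++j) {
    for (std::size_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p)
        acc += A[i + p * lda] * B[p + j * ldb];
      // BLAS convention: beta == 0 overwrites C outright. C may hold garbage
      // on the first call, and 0 * NaN would otherwise poison the result.
      C[i + j * ldc] = (beta == 0.0f) ? alpha * acc
                                      : alpha * acc + beta * C[i + j * ldc];
    }
  }
}

This is why `evalGemmPartial` can drop the up-front `memset` when `HasBeta` is true: the first depth panel (`k2 == k_start`) runs with `beta = 0` and initializes the output buffer as a side effect of the first `invoke()`.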
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 873db5efd..26c9fac17 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -904,14 +904,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType
       const Index nend = n * gn_ + gn(n);
       for (Index n1 = n * gn_; n1 < nend; n1++) {
-        if (k == 0) {
-          // Zero the output memory in parallel.
-          // On 10000x2x10000 mm zeroing can easily take half of time.
-          // Zero (bn x m) row. Safe to do here because all kernels that will
-          // write to this memory depend on completion of this task.
-          // Note: don't call device_.memset() here. device_.memset() blocks on
-          // thread pool worker thread, which can lead to underutilization and
-          // deadlocks.
+        if (!TensorContractionKernel::HasBeta && k == 0) {
+          // Zero the output memory in parallel, but only if the contraction
+          // kernel does not support `beta`. Otherwise we pass beta = 0.0 to
+          // the first call to `TensorContractionKernel::invoke()`.
+          //
+          // On a 10000x2x10000 mm zeroing can easily take half of the time.
+          // Zero the (bn x m) row. Safe to do here because all kernels that
+          // will write to this memory depend on completion of this task.
+          // Note: don't call device_.memset() here; it blocks on a thread
+          // pool worker thread, which can lead to underutilization and deadlocks.
           memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
         }
         kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                         rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
@@ -936,6 +938,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType
       // (rhs fits into L2$ while lhs only into L3$).
      const Index nend = n * gn_ + gn(n);
      const Index mend = m * gm_ + gm(m);
+
+      // NOTE: output = alpha * LHS * RHS + beta * output.
+      const Scalar alpha = Scalar(1);
+      const Scalar beta =
+          (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
+
      if (shard_by_col_) {
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
@@ -944,7 +952,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType
                 output_mapper,
                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
-                bk(k), bn(n1), Scalar(1));
+                bk(k), bn(n1), alpha, beta);

             // We are done with the last task for the [m1, n1] block.
             if (k + 1 == nk_) {
@@ -961,7 +969,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType
                 output_mapper,
                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
-                bk(k), bn(n1), Scalar(1));
+                bk(k), bn(n1), alpha, beta);

             // We are done with the last task for the [m1, n1] block.
             if (k + 1 == nk_) {
@@ -1266,7 +1274,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType
   template <int Alignment>
   void processBlock(Index block_idx, Index begin, Index end) {
     Scalar* buf = block_buffers[block_idx];
-    ::memset(buf, 0, buffer_size_bytes);

     TENSOR_CONTRACTION_DISPATCH(
         evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
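The thread-pool path applies the same idea per depth slice: the `k == 0` slice writes with `beta = 0`, every later slice accumulates with `beta = 1`, and the parallel zeroing (which the old comment notes could take half the runtime on a 10000x2x10000 product) disappears. Below is a minimal single-threaded sketch of that accumulation scheme, reusing the hypothetical `gemm_ref` from the sketch above; it assumes, as in Eigen's scheduler, that the depth slices of a given output block run in order, so the `beta = 0` slice is guaranteed to execute first.

#include <cstddef>

void gemm_ref(std::size_t m, std::size_t n, std::size_t k, float alpha,
              const float* A, std::size_t lda, const float* B, std::size_t ldb,
              float beta, float* C, std::size_t ldc);  // from the sketch above

// Hypothetical driver: contract A (m x k) with B (k x n), both column-major,
// one depth panel of at most bk slices at a time, mirroring the patch's beta
// selection. C may be uninitialized: the first panel overwrites it with
// beta = 0, every later panel accumulates into it with beta = 1.
void contract_by_depth_panels(std::size_t m, std::size_t n, std::size_t k,
                              std::size_t bk, const float* A, const float* B,
                              float* C) {
  for (std::size_t k2 = 0; k2 < k; k2 += bk) {
    const std::size_t kc = (k2 + bk < k) ? bk : k - k2;  // panel depth
    const float beta = (k2 == 0) ? 0.0f : 1.0f;          // first panel zeroes C
    gemm_ref(m, n, kc, /*alpha=*/1.0f,
             A + k2 * m, /*lda=*/m,  // columns k2 .. k2+kc-1 of A
             B + k2, /*ldb=*/k,      // rows    k2 .. k2+kc-1 of B
             beta, C, /*ldc=*/m);
  }
}

The same reasoning explains the deleted `::memset(buf, 0, buffer_size_bytes)` in `processBlock`: once the first `invoke()` over each output block initializes it via `beta = 0`, zeroing the per-block scratch buffers up front is redundant work.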