diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 873db5efd..26c9fac17 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -904,14 +904,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const Index nend = n * gn_ + gn(n); for (Index n1 = n * gn_; n1 < nend; n1++) { - if (k == 0) { - // Zero the output memory in parallel. - // On 10000x2x10000 mm zeroing can easily take half of time. - // Zero (bn x m) row. Safe to do here because all kernels that will - // write to this memory depend on completion of this task. - // Note: don't call device_.memset() here. device_.memset() blocks on - // thread pool worker thread, which can lead to underutilization and - // deadlocks. + if (!TensorContractionKernel::HasBeta && k == 0) { + // Zero the output memory in parallel, only if contraction kernel does + // not support `beta`. Otherwise we will pass beta 0.0 to the first + // call to the `TensorContractionKernel::invoke()`. + // + // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn + // x m) row. Safe to do here because all kernels that will write to + // this memory depend on completion of this task. Note: don't call + // device_.memset() here. device_.memset() blocks on thread pool + // worker thread, which can lead to underutilization and deadlocks. memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar)); } kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local), @@ -936,6 +938,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // (rhs fits into L2$ while lhs only into L3$). const Index nend = n * gn_ + gn(n); const Index mend = m * gm_ + gm(m); + + // NOTE: output = alpha * LHS * RHS + beta * output. + const Scalar alpha = Scalar(1); + const Scalar beta = + (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1); + if (shard_by_col_) { for (Index n1 = n * gn_; n1 < nend; n1++) { for (Index m1 = m * gm_; m1 < mend; m1++) { @@ -944,7 +952,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local), packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), - bk(k), bn(n1), Scalar(1)); + bk(k), bn(n1), alpha, beta); // We are done with the last task for the [m1, n1] block. if (k + 1 == nk_) { @@ -961,7 +969,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local), packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), - bk(k), bn(n1), Scalar(1)); + bk(k), bn(n1), alpha, beta); // We are done with the last task for the [m1, n1] block. if (k + 1 == nk_) { @@ -1266,7 +1274,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT template <int Alignment> void processBlock(Index block_idx, Index begin, Index end) { Scalar* buf = block_buffers[block_idx]; - ::memset(buf, 0, buffer_size_bytes); TENSOR_CONTRACTION_DISPATCH( evaluator->template evalGemmPartialWithoutOutputKernel, Alignment, |