author    Eugene Zhulenev <ezhulenev@google.com> 2019-10-02 11:06:02 -0700
committer Eugene Zhulenev <ezhulenev@google.com> 2019-10-02 11:06:02 -0700
commit    6e40454a6e6cc57c07c7340148657c985ca6c928 (patch)
tree      28e623b2492d69bcff8fa9c54b3a0e64eea08a69 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent    bd0fac456f8ba4fa980a1cbca4b86ac207b82751 (diff)
Add beta to TensorContractionKernel and make memset optional
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h  29
1 file changed, 18 insertions, 11 deletions
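
The change builds on the standard GEMM-style update output = alpha * LHS * RHS + beta * output: when the contraction kernel supports beta, the first k-slice can be issued with beta == 0 (overwrite the output block) and every later slice with beta == 1 (accumulate into it), which makes the up-front memset of the output buffer redundant. The sketch below illustrates that scheme with a hypothetical gemm_block helper; it is not Eigen's TensorContractionKernel interface, just the arithmetic that the alpha/beta selection in the diff implements.

    // Minimal sketch (not Eigen's actual kernel API) of the beta-based update:
    //   output = alpha * LHS * RHS + beta * output
    // beta == 0 on the first k-slice overwrites the (possibly uninitialized)
    // output block; beta == 1 on later slices accumulates into it.
    #include <cstddef>
    #include <vector>

    // Hypothetical block kernel: C = alpha * A * B + beta * C, where A is m x kk,
    // B is kk x n and C is m x n, all dense row-major.
    static void gemm_block(float* C, const float* A, const float* B,
                           std::size_t m, std::size_t n, std::size_t kk,
                           float alpha, float beta) {
      for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
          float acc = 0.0f;
          for (std::size_t p = 0; p < kk; ++p) acc += A[i * kk + p] * B[p * n + j];
          // With beta == 0 the previous contents of C are ignored entirely.
          C[i * n + j] = alpha * acc + beta * C[i * n + j];
        }
      }
    }

    int main() {
      const std::size_t m = 4, n = 4, kk = 4, k_slices = 2;
      // LHS/RHS stored slice-by-slice: k_slices contiguous (m x kk) / (kk x n) blocks.
      std::vector<float> lhs(k_slices * m * kk, 1.0f);
      std::vector<float> rhs(k_slices * kk * n, 1.0f);
      std::vector<float> out(m * n, -123.0f);  // deliberately "dirty"; never memset

      for (std::size_t k = 0; k < k_slices; ++k) {
        const float alpha = 1.0f;
        const float beta = (k == 0) ? 0.0f : 1.0f;  // overwrite first, then accumulate
        gemm_block(out.data(), lhs.data() + k * m * kk, rhs.data() + k * kk * n,
                   m, n, kk, alpha, beta);
      }
      // Every entry of out is now kk * k_slices == 8, even though out was never zeroed.
      return 0;
    }

The removed ::memset(buf, 0, buffer_size_bytes) in processBlock further down appears to follow the same idea: the first update into the block buffer can overwrite rather than accumulate, so pre-zeroing it is no longer required.
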
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 873db5efd..26c9fac17 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -904,14 +904,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index nend = n * gn_ + gn(n);
for (Index n1 = n * gn_; n1 < nend; n1++) {
- if (k == 0) {
- // Zero the output memory in parallel.
- // On 10000x2x10000 mm zeroing can easily take half of time.
- // Zero (bn x m) row. Safe to do here because all kernels that will
- // write to this memory depend on completion of this task.
- // Note: don't call device_.memset() here. device_.memset() blocks on
- // thread pool worker thread, which can lead to underutilization and
- // deadlocks.
+ if (!TensorContractionKernel::HasBeta && k == 0) {
+ // Zero the output memory in parallel, only if contraction kernel does
+ // not support `beta`. Otherwise we will pass beta 0.0 to the first
+ // call to the `TensorContractionKernel::invoke()`.
+ //
+ // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn
+ // x m) row. Safe to do here because all kernels that will write to
+ // this memory depend on completion of this task. Note: don't call
+ // device_.memset() here. device_.memset() blocks on thread pool
+ // worker thread, which can lead to underutilization and deadlocks.
memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
}
kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
@@ -936,6 +938,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// (rhs fits into L2$ while lhs only into L3$).
const Index nend = n * gn_ + gn(n);
const Index mend = m * gm_ + gm(m);
+
+ // NOTE: output = alpha * LHS * RHS + beta * output.
+ const Scalar alpha = Scalar(1);
+ const Scalar beta =
+ (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
+
if (shard_by_col_) {
for (Index n1 = n * gn_; n1 < nend; n1++) {
for (Index m1 = m * gm_; m1 < mend; m1++) {
@@ -944,7 +952,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
output_mapper,
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
- bk(k), bn(n1), Scalar(1));
+ bk(k), bn(n1), alpha, beta);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
@@ -961,7 +969,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
output_mapper,
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
- bk(k), bn(n1), Scalar(1));
+ bk(k), bn(n1), alpha, beta);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
@@ -1266,7 +1274,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
- ::memset(buf, 0, buffer_size_bytes);
TENSOR_CONTRACTION_DISPATCH(
evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,