aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-02 11:06:02 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-02 11:06:02 -0700
commit6e40454a6e6cc57c07c7340148657c985ca6c928 (patch)
tree28e623b2492d69bcff8fa9c54b3a0e64eea08a69 /unsupported/Eigen/CXX11/src/Tensor
parentbd0fac456f8ba4fa980a1cbca4b86ac207b82751 (diff)
Add beta to TensorContractionKernel and make memset optional
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h32
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h29
2 files changed, 39 insertions, 22 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index d61209133..87e8db3fd 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -180,6 +180,10 @@ template <typename ResScalar, typename LhsScalar, typename RhsScalar,
typename StorageIndex, typename OutputMapper, typename LhsMapper,
typename RhsMapper>
struct TensorContractionKernel {
+ // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
+ // (otherwise beta should be always equal to 1).
+ enum { HasBeta = false };
+
EIGEN_DEVICE_FUNC
TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
@@ -248,7 +252,9 @@ struct TensorContractionKernel {
const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
const RhsBlock& rhsBlock, const StorageIndex rows,
const StorageIndex depth, const StorageIndex cols,
- const ResScalar alpha) {
+ const ResScalar alpha, const ResScalar beta) {
+ // Default GEBP kernel does not support beta.
+ eigen_assert(beta == ResScalar(1));
static const int kComputeStrideFromBlockDimensions = -1;
GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
/*strideA*/ kComputeStrideFromBlockDimensions,
@@ -772,15 +778,6 @@ struct TensorContractionEvaluatorBase
void evalGemm(Scalar* buffer) const {
// columns in left side, rows in right side
const Index k = this->m_k_size;
-
- // rows in left side
- const Index m = this->m_i_size;
-
- // columns in right side
- const Index n = this->m_j_size;
-
- // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
- this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
this->template evalGemmPartial<lhs_inner_dim_contiguous,
rhs_inner_dim_contiguous,
rhs_inner_dim_reordered,
@@ -866,6 +863,12 @@ struct TensorContractionEvaluatorBase
const BlockMemHandle packed_mem =
kernel.allocate(this->m_device, &blockA, &blockB);
+ // If a contraction kernel does not support beta, explicitly initialize
+ // output buffer with zeroes.
+ if (!TensorContractionKernel::HasBeta) {
+ this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+ }
+
for(Index i2=0; i2<m; i2+=mc)
{
const Index actual_mc = numext::mini(i2+mc,m)-i2;
@@ -874,6 +877,13 @@ struct TensorContractionEvaluatorBase
const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
+ // If kernel supports beta, there is no need to initialize output
+ // buffer with zeroes.
+ const Scalar alpha = Scalar(1);
+ const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
+ ? Scalar(0)
+ : Scalar(1);
+
// series of horizontal blocks
for (Index j2 = 0; j2 < n; j2 += nc) {
// make sure we don't overshoot right edge of right matrix, then pack block
@@ -885,7 +895,7 @@ struct TensorContractionEvaluatorBase
// The parameters here are copied from Eigen's GEMM implementation
const OutputMapper output_mapper = output.getSubMapper(i2, j2);
kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
- actual_nc, Scalar(1));
+ actual_nc, alpha, beta);
// We are done with this [i2, j2] output block.
if (use_output_kernel && k2 + kc >= k_end) {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 873db5efd..26c9fac17 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -904,14 +904,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index nend = n * gn_ + gn(n);
for (Index n1 = n * gn_; n1 < nend; n1++) {
- if (k == 0) {
- // Zero the output memory in parallel.
- // On 10000x2x10000 mm zeroing can easily take half of time.
- // Zero (bn x m) row. Safe to do here because all kernels that will
- // write to this memory depend on completion of this task.
- // Note: don't call device_.memset() here. device_.memset() blocks on
- // thread pool worker thread, which can lead to underutilization and
- // deadlocks.
+ if (!TensorContractionKernel::HasBeta && k == 0) {
+ // Zero the output memory in parallel, only if contraction kernel does
+ // not support `beta`. Otherwise we will pass beta 0.0 to the first
+ // call to the `TensorContractionKernel::invoke()`.
+ //
+ // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn
+ // x m) row. Safe to do here because all kernels that will write to
+ // this memory depend on completion of this task. Note: don't call
+ // device_.memset() here. device_.memset() blocks on thread pool
+ // worker thread, which can lead to underutilization and deadlocks.
memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
}
kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
@@ -936,6 +938,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// (rhs fits into L2$ while lhs only into L3$).
const Index nend = n * gn_ + gn(n);
const Index mend = m * gm_ + gm(m);
+
+ // NOTE: output = alpha * LHS * RHS + beta * output.
+ const Scalar alpha = Scalar(1);
+ const Scalar beta =
+ (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
+
if (shard_by_col_) {
for (Index n1 = n * gn_; n1 < nend; n1++) {
for (Index m1 = m * gm_; m1 < mend; m1++) {
@@ -944,7 +952,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
output_mapper,
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
- bk(k), bn(n1), Scalar(1));
+ bk(k), bn(n1), alpha, beta);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
@@ -961,7 +969,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
output_mapper,
packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
- bk(k), bn(n1), Scalar(1));
+ bk(k), bn(n1), alpha, beta);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
@@ -1266,7 +1274,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
template <int Alignment>
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
- ::memset(buf, 0, buffer_size_bytes);
TENSOR_CONTRACTION_DISPATCH(
evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,