From 6e40454a6e6cc57c07c7340148657c985ca6c928 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 2 Oct 2019 11:06:02 -0700 Subject: Add beta to TensorContractionKernel and make memset optional --- .../Eigen/CXX11/src/Tensor/TensorContraction.h | 32 ++++++++++++++-------- 1 file changed, 21 insertions(+), 11 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index d61209133..87e8db3fd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -180,6 +180,10 @@ template struct TensorContractionKernel { + // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C` + // (otherwise beta should be always equal to 1). + enum { HasBeta = false }; + EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_, StorageIndex bk_, StorageIndex bn_) @@ -248,7 +252,9 @@ struct TensorContractionKernel { const OutputMapper& output_mapper, const LhsBlock& lhsBlock, const RhsBlock& rhsBlock, const StorageIndex rows, const StorageIndex depth, const StorageIndex cols, - const ResScalar alpha) { + const ResScalar alpha, const ResScalar beta) { + // Default GEBP kernel does not support beta. + eigen_assert(beta == ResScalar(1)); static const int kComputeStrideFromBlockDimensions = -1; GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha, /*strideA*/ kComputeStrideFromBlockDimensions, @@ -772,15 +778,6 @@ struct TensorContractionEvaluatorBase void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; - - // rows in left side - const Index m = this->m_i_size; - - // columns in right side - const Index n = this->m_j_size; - - // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) - this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); this->template evalGemmPartialm_device, &blockA, &blockB); + // If a contraction kernel does not support beta, explicitly initialize + // output buffer with zeroes. + if (!TensorContractionKernel::HasBeta) { + this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); + } + for(Index i2=0; i2= k_end) { -- cgit v1.2.3