From 9f4988959f1b0394ee027f474f49916543ad2f3c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 27 Sep 2018 11:49:19 -0700 Subject: Remove explicit mkldnn support and redundant TensorContractionKernelBlocking --- .../Eigen/CXX11/src/Tensor/TensorContraction.h | 102 ++++++++++++++++++--- 1 file changed, 87 insertions(+), 15 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 3b22e43e7..d220f82be 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -136,6 +136,81 @@ struct traits::NumDimensions + traits::NumDimensions - 2 * array_size::value; }; +// WARNING: In this code we assume that Lhs and Rhs tensor expressions are in +// ColMajor storage order. This property is guaranteed by the +// TensorContractionOp evaluator. TensorContractionKernel specifies how we pack +// blocks of Lhs and Rhs tensor expressions, and how we invoke matrix +// multiplication for these blocks. Default tensor contraction uses +// gemm_pack_rhs, gemm_pack_lhs and gebp_kernel from Eigen Core (see +// GeneralBlockPanelKernel.h for details). +// +// By specializing contraction kernels we can use other low level libraries to +// perform matrix multiplication, and still rely on Eigen contraction evaluator. +// This also includes full support in TensorContractionThreadPool, assuming that +// the underlying gemm does not use its own threading. +// +// - ResScalar/LhsScalar/RhsScalar - scalar type for the result of +// multiplication, lhs tensor and rhs tensor respectively. +// +// - StorageIndex - index type for the tensor expressions. In practice it is +// almost always Eigen::Index. +// +// - OutputMapper provides access to the memory of the output matrix. In +// practice it's always column major blas_data_mapper (it must be of ResScalar +// type). 
+// +// - LhsMapper/RhsMapper similarly to blas_data_mapper provide a two dimensional +// view into the Lhs/Rhs tensor expressions. In practice it's +// TensorContractionInputMapper, or some specialization of it based on the +// type of tensor expression (e.g. TensorImagePatchOp has optimized input +// mapper). +template +struct TensorContractionKernel { + typedef typename internal::gebp_traits Traits; + + typedef internal::gemm_pack_lhs + LhsPacker; + + typedef internal::gemm_pack_rhs + RhsPacker; + + typedef internal::gebp_kernel + GebpKernel; + + EIGEN_DONT_INLINE + static void packLhs(LhsScalar* lhsBlock, + const typename LhsMapper::SubMapper& data_mapper, + const StorageIndex depth, const StorageIndex rows) { + LhsPacker()(lhsBlock, data_mapper, depth, rows, /*stride*/ 0, /*offset*/ 0); + } + + EIGEN_DONT_INLINE + static void packRhs(RhsScalar* rhsBlock, + const typename RhsMapper::SubMapper& data_mapper, + const StorageIndex depth, const StorageIndex cols) { + RhsPacker()(rhsBlock, data_mapper, depth, cols); + } + + EIGEN_DONT_INLINE + static void invoke(const OutputMapper& output_mapper, + const LhsScalar* lhsBlock, const RhsScalar* rhsBlock, + const StorageIndex rows, const StorageIndex depth, + const StorageIndex cols, const ResScalar alpha) { + GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha, + /*strideA*/ -1, /*strideB*/ -1, + /*offsetA*/ 0, /*offsetB*/ 0); + } +}; + } // end namespace internal // Tensor contraction params that should enable to get from output matrix @@ -591,13 +666,9 @@ struct TensorContractionEvaluatorBase // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); - // define mr, nr, and all of my data mapper types + // define data mappers for Lhs and Rhs typedef typename internal::remove_const::type LhsScalar; typedef typename internal::remove_const::type RhsScalar; - typedef typename internal::gebp_traits Traits; - - const 
Index nr = Traits::nr; - const Index mr = Traits::mr; typedef TensorEvaluator LeftEvaluator; typedef TensorEvaluator RightEvaluator; @@ -619,11 +690,9 @@ struct TensorContractionEvaluatorBase typedef internal::blas_data_mapper OutputMapper; - // Declare GEBP packing and kernel structs - internal::gemm_pack_lhs pack_lhs; - internal::gemm_pack_rhs pack_rhs; - - internal::gebp_kernel gebp; + typedef internal::TensorContractionKernel< + Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper> + TensorContractionKernel; // initialize data mappers LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, @@ -635,7 +704,7 @@ struct TensorContractionEvaluatorBase OutputMapper output(buffer, m); // Sizes of the blocks to load in cache. See the Goto paper for details. - internal::TensorContractionBlocking blocking(k, m, n, 1); + internal::TensorContractionBlocking blocking(k, m, n, 1); const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); @@ -651,19 +720,22 @@ struct TensorContractionEvaluatorBase for (Index k2 = 0; k2 < k; k2 += kc) { // make sure we don't overshoot right edge of left matrix, then pack vertical panel const Index actual_kc = numext::mini(k2 + kc, k) - k2; - pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0); + TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2), + actual_kc, actual_mc); // series of horizontal blocks for (Index j2 = 0; j2 < n; j2 += nc) { // make sure we don't overshoot right edge of right matrix, then pack block const Index actual_nc = numext::mini(j2 + nc, n) - j2; - pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0); + TensorContractionKernel::packRhs(blockB, rhs.getSubMapper(k2, j2), + actual_kc, actual_nc); // call gebp (matrix kernel) // The parameters here are copied from Eigen's GEMM implementation const OutputMapper output_mapper = output.getSubMapper(i2, j2); - 
gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc, - Scalar(1), -1, -1, 0, 0); + TensorContractionKernel::invoke(output_mapper, blockA, blockB, + actual_mc, actual_kc, actual_nc, + Scalar(1)); // We are done with this [i2, j2] output block. if (k2 + kc >= k) { -- cgit v1.2.3