From e4785326255c536214d2cead384477c35e3bdcc6 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Fri, 27 Jul 2018 12:36:34 -0700 Subject: Reduce the number of template specializations of classes related to tensor contraction to reduce binary size. --- .../CXX11/src/Tensor/TensorContractionThreadPool.h | 118 +++++++++------------ 1 file changed, 49 insertions(+), 69 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 8b86d7aaf..182c5f7f9 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -71,8 +71,7 @@ struct TensorEvaluator + template void evalProduct(Scalar* buffer) const { const Index m = this->m_i_size; const Index n = this->m_j_size; @@ -96,39 +95,6 @@ struct TensorEvaluator::type - LhsScalar; - typedef - typename internal::remove_const::type - RhsScalar; - typedef typename internal::gebp_traits Traits; - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; - typedef internal::TensorContractionInputMapper< - LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, - contract_t, internal::packet_traits::size, - lhs_inner_dim_contiguous, false, Unaligned> - LhsMapper; - typedef internal::TensorContractionInputMapper< - RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t, - contract_t, internal::packet_traits::size, - rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned> - RhsMapper; - typedef internal::blas_data_mapper OutputMapper; - typedef internal::gemm_pack_lhs - LhsPacker; - typedef internal::gemm_pack_rhs< - RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> - RhsPacker; - typedef internal::gebp_kernel - GebpKernel; - - - // Compute a set of algorithm parameters: // - kernel block sizes (bm, bn, bk) // - task grain sizes (number of kernels executed per task: gm, gn) @@ -158,14 +124,14 @@ struct TensorEvaluator blocking(k, m, n, 2); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking blocking(k, m, n, 2); bm = blocking.mc(); @@ -187,29 +153,22 @@ struct TensorEvaluatortemplate evalGemv(buffer); - else - this->template evalGemm(buffer); + TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, + Unaligned, (buffer)); return; } // Now that we know number of threads, recalculate sharding and blocking. shard_by_col = shardByCol(m, n, num_threads); if (shard_by_col) { - internal::TensorContractionBlocking blocking(k, m, n, num_threads); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking blocking(k, m, n, num_threads); bm = blocking.mc(); @@ -257,34 +216,55 @@ struct TensorEvaluatorm_leftImpl, this->m_left_nocontract_strides, - this->m_i_strides, this->m_left_contracting_strides, - this->m_k_strides); + #define CONTEXT_ARGS \ + (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \ + nn0, shard_by_col, parallel_pack) \ + .run() + + TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS); - RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, - this->m_j_strides, this->m_right_contracting_strides, - this->m_k_strides); +#undef CONTEXT_ARGS - Context(this, num_threads, lhs, rhs, buffer, m, n, - k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, - shard_by_col, parallel_pack) - .run(); } // Context coordinates a single parallel gemm operation. - template + template class Context { public: - Context(const Self* self, int num_threads, LhsMapper& lhs, - RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, - Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, - Index gn, Index nm0, Index nn0, bool shard_by_col, + typedef internal::TensorContractionInputMapper< + LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, + contract_t, internal::packet_traits::size, + lhs_inner_dim_contiguous, false, Unaligned> + LhsMapper; + typedef internal::TensorContractionInputMapper< + RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t, + contract_t, internal::packet_traits::size, + rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned> + RhsMapper; + typedef internal::gemm_pack_lhs + LhsPacker; + typedef internal::gemm_pack_rhs< + RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> + RhsPacker; + typedef internal::blas_data_mapper OutputMapper; + typedef internal::gebp_kernel + GebpKernel; + + Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, + Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, + Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) : device_(self->m_device), - lhs_(lhs), - rhs_(rhs), + lhs_(self->m_leftImpl, self->m_left_nocontract_strides, + self->m_i_strides, self->m_left_contracting_strides, + self->m_k_strides), + rhs_(self->m_rightImpl, self->m_right_nocontract_strides, + self->m_j_strides, self->m_right_contracting_strides, + self->m_k_strides), buffer_(buffer), output_(buffer, tm), output_kernel_(self->m_output_kernel), @@ -376,8 +356,8 @@ struct TensorEvaluator