diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-07-27 12:36:34 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-07-27 12:36:34 -0700 |
commit | e4785326255c536214d2cead384477c35e3bdcc6 (patch) | |
tree | f25f1edb58ed5f6126de1ff76c6ea52c6d5a311e /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | |
parent | 2ebcb911b27174c5402e4c7af3d2738fd042a5e2 (diff) |
Reduce the number of template specializations of classes related to tensor contraction to reduce binary size.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 105 |
1 files changed, 58 insertions, 47 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 0e69cd40c..57b5339d1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -177,9 +177,9 @@ struct NoOpOutputKernel { */ template <typename Index, typename Scalar> EIGEN_ALWAYS_INLINE void operator()( - const OutputKernel::OutputMapper<Index, Scalar>& output_mapper, - const TensorContractionParams& params, Index i, Index j, Index num_rows, - Index num_cols) const {} + const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/, + const TensorContractionParams& /*params*/, Index /*i*/, + Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {} }; template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel> @@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase } } - EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { - if (this->m_lhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer); - } - } +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, true, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<true, true, false, ALIGNMENT>ARGS; \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, false, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<true, false, false, ALIGNMENT>ARGS; \ + } \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, true, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<false, true, false, ALIGNMENT>ARGS; \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, false, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<false, false, false, ALIGNMENT>ARGS; \ + } \ + } \ } - else { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer); - } - } + + EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { + static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer); + } + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, + bool rhs_inner_dim_reordered, int Alignment> + void evalProductSequential(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv<lhs_inner_dim_contiguous, + rhs_inner_dim_contiguous, rhs_inner_dim_reordered, + Alignment>(buffer); + } else { + this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, + rhs_inner_dim_reordered, Alignment>(buffer); } } @@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase OutputMapper output(buffer, m); // Sizes of the blocks to load in cache. See the Goto paper for details. - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1); + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1); const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); @@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } - template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> - EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const { - if (this->m_j_size == 1) { - this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); - return; - } - - this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + template <int Alignment> + void evalProduct(Scalar* buffer) const { + TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer)); } }; |