diff options
Diffstat (limited to 'unsupported')
3 files changed, 109 insertions, 121 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 0e69cd40c..57b5339d1 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -177,9 +177,9 @@ struct NoOpOutputKernel { */ template <typename Index, typename Scalar> EIGEN_ALWAYS_INLINE void operator()( - const OutputKernel::OutputMapper<Index, Scalar>& output_mapper, - const TensorContractionParams& params, Index i, Index j, Index num_rows, - Index num_cols) const {} + const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/, + const TensorContractionParams& /*params*/, Index /*i*/, + Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {} }; template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel> @@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase } } - EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { - if (this->m_lhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer); - } - } +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, true, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<true, true, false, ALIGNMENT>ARGS; \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<true, false, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<true, false, false, ALIGNMENT>ARGS; \ + } \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_contiguous) { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, true, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<false, true, false, ALIGNMENT>ARGS; \ + } \ + } \ + else { \ + if (this->m_rhs_inner_dim_reordered) { \ + METHOD<false, false, true, ALIGNMENT>ARGS; \ + } \ + else { \ + METHOD<false, false, false, ALIGNMENT>ARGS; \ + } \ + } \ } - else { - if (this->m_rhs_inner_dim_contiguous) { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer); - } - } - else { - if (this->m_rhs_inner_dim_reordered) { - static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer); - } - else { - static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer); - } - } + + EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const { + static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer); + } + + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, + bool rhs_inner_dim_reordered, int Alignment> + void evalProductSequential(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv<lhs_inner_dim_contiguous, + rhs_inner_dim_contiguous, rhs_inner_dim_reordered, + Alignment>(buffer); + } else { + this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, + rhs_inner_dim_reordered, Alignment>(buffer); } } @@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase OutputMapper output(buffer, m); // Sizes of the blocks to load in cache. See the Goto paper for details. - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1); + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1); const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); @@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } - template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> - EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const { - if (this->m_j_size == 1) { - this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); - return; - } - - this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); + template <int Alignment> + void evalProduct(Scalar* buffer) const { + TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer)); } }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h index 8c1af1da8..cf281192c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h @@ -21,13 +21,10 @@ enum { // Default Blocking Strategy -template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol> +template <typename LhsScalar, typename RhsScalar, typename Index, int ShardingType=ShardByCol> class TensorContractionBlocking { public: - typedef typename LhsMapper::Scalar LhsScalar; - typedef typename RhsMapper::Scalar RhsScalar; - /* adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h` requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h` @@ -41,7 +38,7 @@ class TensorContractionBlocking { ../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901: dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function */ - + #if !defined(EIGEN_HIPCC) EIGEN_DEVICE_FUNC #endif diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 8b86d7aaf..182c5f7f9 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -71,8 +71,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {} - template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, - bool rhs_inner_dim_reordered, int Alignment> + template <int Alignment> void evalProduct(Scalar* buffer) const { const Index m = this->m_i_size; const Index n = this->m_j_size; @@ -96,39 +95,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } #endif - typedef - typename internal::remove_const<typename EvalLeftArgType::Scalar>::type - LhsScalar; - typedef - typename internal::remove_const<typename EvalRightArgType::Scalar>::type - RhsScalar; - typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits; - typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator; - typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator; - typedef internal::TensorContractionInputMapper< - LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, - contract_t, internal::packet_traits<LhsScalar>::size, - lhs_inner_dim_contiguous, false, Unaligned> - LhsMapper; - typedef internal::TensorContractionInputMapper< - RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t, - contract_t, internal::packet_traits<RhsScalar>::size, - rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned> - RhsMapper; - typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; - typedef internal::gemm_pack_lhs<LhsScalar, Index, - typename LhsMapper::SubMapper, Traits::mr, - Traits::LhsProgress, ColMajor> - LhsPacker; - typedef internal::gemm_pack_rhs< - RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> - RhsPacker; - typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, - Traits::mr, Traits::nr, false, false> - GebpKernel; - - - // Compute a set of algorithm parameters: // - kernel block sizes (bm, bn, bk) // - task grain sizes (number of kernels executed per task: gm, gn) @@ -158,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // Again, we don't know number of threads yet, so we use 2. Index bm, bn, bk; if (shard_by_col) { - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 2); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n, 2); bm = blocking.mc(); @@ -187,29 +153,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT if (n == 1) num_threads = 1; if (num_threads == 1) { - // The single-threaded algorithm should be faster in this case. - if (n == 1) - this->template evalGemv<lhs_inner_dim_contiguous, - rhs_inner_dim_contiguous, - rhs_inner_dim_reordered, Alignment>(buffer); - else - this->template evalGemm<lhs_inner_dim_contiguous, - rhs_inner_dim_contiguous, - rhs_inner_dim_reordered, Alignment>(buffer); + TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, + Unaligned, (buffer)); return; } // Now that we know number of threads, recalculate sharding and blocking. shard_by_col = shardByCol(m, n, num_threads); if (shard_by_col) { - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, num_threads); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, + internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n, num_threads); bm = blocking.mc(); @@ -257,34 +216,55 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // more important in this case. if ((shard_by_col ? nm : nn) == 1) parallel_pack = false; - LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, - this->m_i_strides, this->m_left_contracting_strides, - this->m_k_strides); + #define CONTEXT_ARGS \ + (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \ + nn0, shard_by_col, parallel_pack) \ + .run() + + TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS); - RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, - this->m_j_strides, this->m_right_contracting_strides, - this->m_k_strides); +#undef CONTEXT_ARGS - Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper, - OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n, - k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, - shard_by_col, parallel_pack) - .run(); } // Context coordinates a single parallel gemm operation. - template <typename LhsPacker, typename RhsPacker, typename GebpKernel, - typename LhsMapper, typename RhsMapper, typename OutputMapper> + template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, + bool rhs_inner_dim_reordered, int Alignment> class Context { public: - Context(const Self* self, int num_threads, LhsMapper& lhs, - RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, - Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, - Index gn, Index nm0, Index nn0, bool shard_by_col, + typedef internal::TensorContractionInputMapper< + LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t, + contract_t, internal::packet_traits<LhsScalar>::size, + lhs_inner_dim_contiguous, false, Unaligned> + LhsMapper; + typedef internal::TensorContractionInputMapper< + RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t, + contract_t, internal::packet_traits<RhsScalar>::size, + rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned> + RhsMapper; + typedef internal::gemm_pack_lhs<LhsScalar, Index, + typename LhsMapper::SubMapper, Traits::mr, + Traits::LhsProgress, ColMajor> + LhsPacker; + typedef internal::gemm_pack_rhs< + RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> + RhsPacker; + typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; + typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, + Traits::mr, Traits::nr, false, false> + GebpKernel; + + Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, + Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, + Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) : device_(self->m_device), - lhs_(lhs), - rhs_(rhs), + lhs_(self->m_leftImpl, self->m_left_nocontract_strides, + self->m_i_strides, self->m_left_contracting_strides, + self->m_k_strides), + rhs_(self->m_rightImpl, self->m_right_nocontract_strides, + self->m_j_strides, self->m_right_contracting_strides, + self->m_k_strides), buffer_(buffer), output_(buffer, tm), output_kernel_(self->m_output_kernel), @@ -376,8 +356,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT private: Notification done_; const Device& device_; - LhsMapper& lhs_; - RhsMapper& rhs_; + LhsMapper lhs_; + RhsMapper rhs_; Scalar* const buffer_; OutputMapper output_; OutputKernelType output_kernel_; |