diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2018-07-10 13:16:38 -0700 |
commit | 01fd4096d395e7b816459f571bf2328c8435cc37 (patch) | |
tree | 02b928b34f77c3e63126c3175b6ea06174818f51 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | |
parent | 5539587b1f5b5922b2419b0a4468cf2f393def51 (diff) |
Fuse computations into the Tensor contractions using output kernel
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 51 |
1 files changed, 38 insertions, 13 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 3c007b183..d7536bd6a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -56,16 +56,16 @@ struct packRhsAndKernelArg { } // end namespace internal #endif // EIGEN_USE_SIMPLE_THREAD_POOL -template<typename Indices, typename LeftArgType, typename RightArgType> -struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> : - public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > { +template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType> +struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> : + public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > { typedef ThreadPoolDevice Device; - typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; + typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; typedef TensorContractionEvaluatorBase<Self> Base; - typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; + typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT this->m_k_strides); Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper, - OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n, + OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack) .run(); @@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT typename LhsMapper, typename RhsMapper, typename OutputMapper> class Context { public: - Context(const Device& device, int num_threads, LhsMapper& lhs, + Context(const Self* self, int num_threads, LhsMapper& lhs, RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) - : device_(device), + : device_(self->m_device), lhs_(lhs), rhs_(rhs), buffer_(buffer), output_(buffer, tm), + output_kernel_(self->m_output_kernel), + tensor_contraction_params_(self->m_tensor_contraction_params), num_threads_(num_threads), shard_by_col_(shard_by_col), parallel_pack_(parallel_pack), @@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT RhsMapper& rhs_; Scalar* const buffer_; OutputMapper output_; + OutputKernelType output_kernel_; + TensorContractionParams tensor_contraction_params_; const int num_threads_; const bool shard_by_col_; const bool parallel_pack_; @@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT const Index mend = m * gm_ + gm(m); if (shard_by_col_) { for (Index n1 = n * gn_; n1 < nend; n1++) { - for (Index m1 = m * gm_; m1 < mend; m1++) - GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), - packed_lhs_[k % (P - 1)][m1], + for (Index m1 = m * gm_; m1 < mend; m1++) { + const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); + GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1), -1, -1, 0, 0); + + // We are done with the last task for the [m1, n1] block. + if (k + 1 == nk_) { + output_kernel_(output_mapper, tensor_contraction_params_, + m1 * bm_, n1 * bn_, bm(m1), bn(n1)); + } + } } } else { for (Index m1 = m * gm_; m1 < mend; m1++) for (Index n1 = n * gn_; n1 < nend; n1++) { - GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), - packed_lhs_[k % (P - 1)][m1], + const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); + GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1), -1, -1, 0, 0); + + // We are done with the last task for the [m1, n1] block. + if (k + 1 == nk_) { + output_kernel_(output_mapper, tensor_contraction_params_, + m1 * bm_, n1 * bn_, bm(m1), bn(n1)); + } } } signal_kernel(m, n, k + 1, false); @@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } #else // EIGEN_USE_SIMPLE_THREAD_POOL + // TODO(ezhulenev): SimpleThreadPool will be removed in the future, and seems + // like it's not worth adding output kernel support here. + static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "SimpleThreadPool does not support contraction output kernels."); template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> void evalProduct(Scalar* buffer) const { @@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT } #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) + // TODO(ezhulenev): Add support for output kernels and LIBXSMM. + static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, + "XSMM does not support contraction output kernels."); + template<int Alignment> class ContextXsmm { public: |