From 01fd4096d395e7b816459f571bf2328c8435cc37 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 10 Jul 2018 13:16:38 -0700 Subject: Fuse computations into the Tensor contractions using output kernel --- .../CXX11/src/Tensor/TensorContractionThreadPool.h | 51 ++++++++++++++++------ 1 file changed, 38 insertions(+), 13 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 3c007b183..d7536bd6a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -56,16 +56,16 @@ struct packRhsAndKernelArg { } // end namespace internal #endif // EIGEN_USE_SIMPLE_THREAD_POOL -template -struct TensorEvaluator, ThreadPoolDevice> : - public TensorContractionEvaluatorBase, ThreadPoolDevice> > { +template +struct TensorEvaluator, ThreadPoolDevice> : + public TensorContractionEvaluatorBase, ThreadPoolDevice> > { typedef ThreadPoolDevice Device; - typedef TensorEvaluator, Device> Self; + typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; - typedef TensorContractionOp XprType; + typedef TensorContractionOp XprType; typedef typename internal::remove_const::type Scalar; typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; @@ -308,7 +308,7 @@ struct TensorEvaluatorm_k_strides); Context(this->m_device, num_threads, lhs, rhs, buffer, m, n, + OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack) .run(); @@ -319,16 +319,18 @@ struct TensorEvaluator class Context { public: - Context(const Device& device, int num_threads, LhsMapper& lhs, + Context(const Self* self, int num_threads, LhsMapper& lhs, RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col, bool parallel_pack) - : device_(device), + : device_(self->m_device), lhs_(lhs), rhs_(rhs), buffer_(buffer), output_(buffer, tm), + output_kernel_(self->m_output_kernel), + tensor_contraction_params_(self->m_tensor_contraction_params), num_threads_(num_threads), shard_by_col_(shard_by_col), parallel_pack_(parallel_pack), @@ -420,6 +422,8 @@ struct TensorEvaluator::value, + "SimpleThreadPool does not support contraction output kernels."); template void evalProduct(Scalar* buffer) const { @@ -1065,6 +1086,10 @@ struct TensorEvaluator::value, + "XSMM does not support contraction output kernels."); + template class ContextXsmm { public: -- cgit v1.2.3