author     Eugene Zhulenev <ezhulenev@google.com>  2018-07-10 13:16:38 -0700
committer  Eugene Zhulenev <ezhulenev@google.com>  2018-07-10 13:16:38 -0700
commit     01fd4096d395e7b816459f571bf2328c8435cc37 (patch)
tree       02b928b34f77c3e63126c3175b6ea06174818f51 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent     5539587b1f5b5922b2419b0a4468cf2f393def51 (diff)
Fuse computations into the Tensor contractions using output kernel
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--  unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h  51
1 file changed, 38 insertions, 13 deletions
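
For context, this change lets a user-supplied "output kernel" run on each finished block of the contraction result, fusing post-processing into the contraction itself instead of requiring a separate pass over the output. Below is a minimal sketch of how such a kernel could look and be plugged in; ReluOutputKernel, the tensor sizes, and the thread-pool setup are illustrative assumptions, and the contract() overload taking a kernel argument is assumed to accompany this change set. The operator() argument list mirrors the output_kernel_(output_mapper, params, i, j, num_rows, num_cols) call sites added in this patch.

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

// Hypothetical output kernel: applies a ReLU to every coefficient of a
// finished output block. `output_mapper` addresses the block that starts at
// row i, column j of the result; `params` describes how the contraction was
// mapped onto a matrix product (unused here).
struct ReluOutputKernel {
  template <typename OutputMapper, typename Index>
  void operator()(const OutputMapper& output_mapper,
                  const Eigen::TensorContractionParams& params,
                  Index i, Index j, Index num_rows, Index num_cols) const {
    for (Index c = 0; c < num_cols; ++c) {
      for (Index r = 0; r < num_rows; ++r) {
        float& v = output_mapper(r, c);
        v = v > 0.0f ? v : 0.0f;
      }
    }
  }
};

int main() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, 4);

  Eigen::Tensor<float, 2> lhs(128, 64), rhs(64, 32), out(128, 32);
  lhs.setRandom();
  rhs.setRandom();
  const Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};

  // Assumed contract() overload taking an output kernel; the kernel runs once
  // per output block, right after the last k-slice for that block has been
  // accumulated (the `k + 1 == nk_` check in the diff below).
  out.device(device) = lhs.contract(rhs, dims, ReluOutputKernel());
  return 0;
}

Because the default kernel type (NoOpOutputKernel, referenced by the static_asserts further down) does nothing, contractions that do not pass a kernel keep their existing behavior.
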
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 3c007b183..d7536bd6a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -56,16 +56,16 @@ struct packRhsAndKernelArg {
} // end namespace internal
#endif // EIGEN_USE_SIMPLE_THREAD_POOL
-template<typename Indices, typename LeftArgType, typename RightArgType>
-struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
- public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
+template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
+struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
+ public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
typedef ThreadPoolDevice Device;
- typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+ typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
- typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+ typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
this->m_k_strides);
Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
- OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
+ OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
shard_by_col, parallel_pack)
.run();
@@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typename LhsMapper, typename RhsMapper, typename OutputMapper>
class Context {
public:
- Context(const Device& device, int num_threads, LhsMapper& lhs,
+ Context(const Self* self, int num_threads, LhsMapper& lhs,
RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
Index gn, Index nm0, Index nn0, bool shard_by_col,
bool parallel_pack)
- : device_(device),
+ : device_(self->m_device),
lhs_(lhs),
rhs_(rhs),
buffer_(buffer),
output_(buffer, tm),
+ output_kernel_(self->m_output_kernel),
+ tensor_contraction_params_(self->m_tensor_contraction_params),
num_threads_(num_threads),
shard_by_col_(shard_by_col),
parallel_pack_(parallel_pack),
@@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
RhsMapper& rhs_;
Scalar* const buffer_;
OutputMapper output_;
+ OutputKernelType output_kernel_;
+ TensorContractionParams tensor_contraction_params_;
const int num_threads_;
const bool shard_by_col_;
const bool parallel_pack_;
@@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index mend = m * gm_ + gm(m);
if (shard_by_col_) {
for (Index n1 = n * gn_; n1 < nend; n1++) {
- for (Index m1 = m * gm_; m1 < mend; m1++)
- GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
- packed_lhs_[k % (P - 1)][m1],
+ for (Index m1 = m * gm_; m1 < mend; m1++) {
+ const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
+ GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
Scalar(1), -1, -1, 0, 0);
+
+ // We are done with the last task for the [m1, n1] block.
+ if (k + 1 == nk_) {
+ output_kernel_(output_mapper, tensor_contraction_params_,
+ m1 * bm_, n1 * bn_, bm(m1), bn(n1));
+ }
+ }
}
} else {
for (Index m1 = m * gm_; m1 < mend; m1++)
for (Index n1 = n * gn_; n1 < nend; n1++) {
- GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
- packed_lhs_[k % (P - 1)][m1],
+ const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
+ GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
Scalar(1), -1, -1, 0, 0);
+
+ // We are done with the last task for the [m1, n1] block.
+ if (k + 1 == nk_) {
+ output_kernel_(output_mapper, tensor_contraction_params_,
+ m1 * bm_, n1 * bn_, bm(m1), bn(n1));
+ }
}
}
signal_kernel(m, n, k + 1, false);
@@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
#else // EIGEN_USE_SIMPLE_THREAD_POOL
+ // TODO(ezhulenev): SimpleThreadPool will be removed in the future, and seems
+ // like it's not worth adding output kernel support here.
+ static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
+ "SimpleThreadPool does not support contraction output kernels.");
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalProduct(Scalar* buffer) const {
@@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
}
#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+ // TODO(ezhulenev): Add support for output kernels and LIBXSMM.
+ static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
+ "XSMM does not support contraction output kernels.");
+
template<int Alignment>
class ContextXsmm {
public:
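
The two static_asserts added above restrict the SimpleThreadPool and LIBXSMM code paths to the default no-op kernel. Roughly, such a default looks like the sketch below; the name NoOpOutputKernel comes from the static_assert messages, while the exact typing here is an assumption reconstructed from the call sites in this diff.

#include <unsupported/Eigen/CXX11/Tensor>

// Sketch of a do-nothing output kernel matching the interface the threaded
// evaluator calls once per finished output block; it ignores its arguments,
// so fusing it adds no work.
struct NoOpOutputKernelSketch {
  template <typename OutputMapper, typename Index>
  void operator()(const OutputMapper& /*output_mapper*/,
                  const Eigen::TensorContractionParams& /*params*/,
                  Index /*i*/, Index /*j*/,
                  Index /*num_rows*/, Index /*num_cols*/) const {}
};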