diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 47 |
1 files changed, 22 insertions, 25 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 70eec78fe..4553c3785 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -124,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // Again, we don't know number of threads yet, so we use 2. Index bm, bn, bk; if (shard_by_col) { - internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, + internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 2); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, + internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n, 2); bm = blocking.mc(); @@ -169,14 +169,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // Now that we know number of threads, recalculate sharding and blocking. shard_by_col = shardByCol(m, n, num_threads); if (shard_by_col) { - internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, + internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, num_threads); bm = blocking.mc(); bn = blocking.nc(); bk = blocking.kc(); } else { - internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, + internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n, num_threads); bm = blocking.mc(); @@ -250,17 +250,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT contract_t, internal::packet_traits<RhsScalar>::size, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned> RhsMapper; - typedef internal::gemm_pack_lhs< - LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr, - Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> - LhsPacker; - typedef internal::gemm_pack_rhs< - RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> - RhsPacker; + typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper; - typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, - Traits::mr, Traits::nr, false, false> - GebpKernel; + + typedef internal::TensorContractionKernel< + Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper> + TensorContractionKernel; Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk, @@ -442,8 +437,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT void pack_lhs(Index m, Index k) { const Index mend = m * gm_ + gm(m); for (Index m1 = m * gm_; m1 < mend; m1++) - LhsPacker()(packed_lhs_[k % (P - 1)][m1], - lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1)); + TensorContractionKernel::packLhs(packed_lhs_[k % (P - 1)][m1], + lhs_.getSubMapper(m1 * bm_, k * bk_), + bk(k), bm(m1)); if (!parallel_pack_ && shard_by_col_) { signal_packing(k); @@ -466,8 +462,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // deadlocks. memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar)); } - RhsPacker()(packed_rhs_[k % (P - 1)][n1], - rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1)); + TensorContractionKernel::packRhs(packed_rhs_[k % (P - 1)][n1], + rhs_.getSubMapper(k * bk_, n1 * bn_), + bk(k), bn(n1)); } if (parallel_pack_ || shard_by_col_) { @@ -488,9 +485,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT for (Index n1 = n * gn_; n1 < nend; n1++) { for (Index m1 = m * gm_; m1 < mend; m1++) { const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); - GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], - packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), - Scalar(1), -1, -1, 0, 0); + TensorContractionKernel::invoke( + output_mapper, packed_lhs_[k % (P - 1)][m1], + packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1)); // We are done with the last task for the [m1, n1] block. if (k + 1 == nk_) { @@ -503,9 +500,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT for (Index m1 = m * gm_; m1 < mend; m1++) for (Index n1 = n * gn_; n1 < nend; n1++) { const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_); - GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1], - packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), - Scalar(1), -1, -1, 0, 0); + TensorContractionKernel::invoke( + output_mapper, packed_lhs_[k % (P - 1)][m1], + packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), Scalar(1)); // We are done with the last task for the [m1, n1] block. if (k + 1 == nk_) { @@ -763,7 +760,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // (contraction) dimension (k). static bool shardByInnerDim(Index m, Index n, Index k, int num_threads, int num_threads_by_k) { - size_t bufsize = m * n * sizeof(Scalar); + std::ptrdiff_t bufsize = m * n * sizeof(Scalar); bool shard_by_k = false; if (n == 1 || // If mat*vec or... num_threads_by_k < 2 || // running single threaded or... |