about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
author: Rasmus Munk Larsen <rmlarsen@google.com> 2019-04-12 13:35:10 -0700
committer: Rasmus Munk Larsen <rmlarsen@google.com> 2019-04-12 13:35:10 -0700
commit: 039ee521250eab33e9f7aadc5ba2baef9661673c (patch)
tree: df91e39821422f793c016c39f2f06f0417a70917 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent: 9a3f06d836dd40ab243521fc3a87425563e2aa11 (diff)
Tweak cost model for tensor contraction when parallelizing over the inner dimension.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 6
1 file changed, 3 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index caa8d1767..500f63e60 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -1169,7 +1169,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
// Compute cost.
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
- TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n);
+ TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
// Output stores.
cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
@@ -1192,8 +1192,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
int num_threads = 1;
double min_cost = total_parallel_cost;
double kPerThreadOverHead = 4000;
- double kFixedOverHead = 100000;
- for (int nt = 2; nt <= this->m_device.numThreads(); nt++) {
+ double kFixedOverHead = 50000;
+ for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
double sequential_cost =
kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
double parallel_cost = total_parallel_cost / nt + sequential_cost;