aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-06-09 08:25:22 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-06-09 08:25:22 -0700
commit14a112ee153f7d8554c3a8848e8a0461c7b82f13 (patch)
tree762b66ca2203bfd720fa7d8115eebdc9724fde80 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent8f92c26319a6e06cce6d0ba2a252521c8096a2c0 (diff)
Use signed integers more consistently to encode the number of threads to use to evaluate a tensor expression.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h12
1 files changed, 6 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index a60a17049..ee16cde9b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -202,7 +202,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
// across k dimension.
const TensorOpCost cost =
contractionCost(m, n, bm, bn, bk, shard_by_col, false);
- Index num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
+ int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
static_cast<double>(n) * m, cost, this->m_device.numThreads());
// TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
@@ -301,7 +301,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
class Context {
public:
Context(const Device& device, int num_threads, LhsMapper& lhs,
- RhsMapper& rhs, Scalar* buffer, Index m, Index n, Index k, Index bm,
+ RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
Index gn, Index nm0, Index nn0, bool shard_by_col,
bool parallel_pack)
@@ -309,13 +309,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
lhs_(lhs),
rhs_(rhs),
buffer_(buffer),
- output_(buffer, m),
+ output_(buffer, tm),
num_threads_(num_threads),
shard_by_col_(shard_by_col),
parallel_pack_(parallel_pack),
- m_(m),
- n_(n),
- k_(k),
+ m_(tm),
+ n_(tn),
+ k_(tk),
bm_(bm),
bn_(bn),
bk_(bk),