diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index dc0513305..8e4c7c11d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -152,7 +152,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT Index mc = m; Index nc = n; Index kc = k; - internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc/*, num_threads*/); + internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads); eigen_assert(mc <= m); eigen_assert(nc <= n); eigen_assert(kc <= k); @@ -197,9 +197,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT // this should really be numBlockAs * n_blocks; const Index num_kernel_promises = num_threads * n_blocks; - Promise p; - p.set_value(); - std::vector<Promise> kernel_promises(num_kernel_promises, p); + std::vector<Promise> kernel_promises(num_kernel_promises); + for (int i = 0; i < kernel_promises.size(); ++i) { + kernel_promises[i].set_value(); + } for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) { const Index k_start = k_block_idx * kc; @@ -275,8 +276,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT need_to_pack, // need_to_pack }; - typedef decltype(Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>) Func; - this->m_device.enqueueNoFuture<Func, packRKArg>(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg); + this->m_device.enqueueNoFuture(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg); } } } @@ -338,7 +338,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT actual_mc, arg.kc, arg.nc, 1.0, -1, -1, 0, 0); const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx; - eigen_assert(!(*arg.kernel_promises)[set_idx].ready()); (*arg.kernel_promises)[set_idx].set_value(); } } |