aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2018-07-31 22:38:28 +0000
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2018-07-31 22:38:28 +0000
commitedf46bd7a27ef1088efc2116196c088d59d22b4a (patch)
tree5b5c5f6a0634c7d3b6015ec835f20bdede9c6ca5 /unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
parent678a0dcb12d55e1d85aade7b34c706b2a5d2d49e (diff)
parent1eff6cf8a77f1b8699671d31f8f307a6fd9170ea (diff)
Merged in yuefengz/eigen (pull request PR-370)
Use device's allocate function instead of internal::aligned_malloc.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h4
1 files changed, 2 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 182c5f7f9..1d145c4b1 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -317,7 +317,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
size_t rhs_size =
divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
- packed_mem_ = static_cast<char*>(internal::aligned_malloc(
+ packed_mem_ = static_cast<char*>(device_.allocate(
(nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
char* mem = static_cast<char*>(packed_mem_);
for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
@@ -339,7 +339,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
delete[] state_kernel_[x];
}
- internal::aligned_free(packed_mem_);
+ device_.deallocate(packed_mem_);
}
void run() {