aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-30 15:13:38 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-08-30 15:13:38 -0700
commitf0b36fb9a405400e82b73ea70097b8ae3cd1095a (patch)
treed3a2903422799257720d2d4989bcd845ab2ae27e /unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
parent619cea94916e7531a839ee0ff657714857921db8 (diff)
evalSubExprsIfNeededAsync + async TensorContractionThreadPool
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h139
1 files changed, 95 insertions, 44 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index a398b2b3f..2f8656fbb 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -373,13 +373,13 @@ struct TensorContractionEvaluatorBase
typedef typename Storage::Type EvaluatorPointerType;
enum {
- IsAligned = true,
- PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
- BlockAccess = false,
+ IsAligned = true,
+ PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
+ BlockAccess = false,
PreferBlockAccess = false,
- Layout = TensorEvaluator<LeftArgType, Device>::Layout,
- CoordAccess = false, // to be implemented
- RawAccess = true
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
+ RawAccess = true
};
// Most of the code is assuming that both input tensors are ColMajor. If the
@@ -390,7 +390,7 @@ struct TensorContractionEvaluatorBase
static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
typedef typename internal::conditional<
static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
-
+
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluatorType;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluatorType;
@@ -605,48 +605,99 @@ struct TensorContractionEvaluatorBase
}
}
-#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
- if (this->m_lhs_inner_dim_contiguous) { \
- if (this->m_rhs_inner_dim_contiguous) { \
- if (this->m_rhs_inner_dim_reordered) { \
- METHOD<true, true, true, ALIGNMENT>ARGS; \
- } \
- else { \
- METHOD<true, true, false, ALIGNMENT>ARGS; \
- } \
- } \
- else { \
- if (this->m_rhs_inner_dim_reordered) { \
- METHOD<true, false, true, ALIGNMENT>ARGS; \
- } \
- else { \
- METHOD<true, false, false, ALIGNMENT>ARGS; \
- } \
- } \
- } \
- else { \
- if (this->m_rhs_inner_dim_contiguous) { \
- if (this->m_rhs_inner_dim_reordered) { \
- METHOD<false, true, true, ALIGNMENT>ARGS; \
- } \
- else { \
- METHOD<false, true, false, ALIGNMENT>ARGS; \
- } \
- } \
- else { \
- if (this->m_rhs_inner_dim_reordered) { \
- METHOD<false, false, true, ALIGNMENT>ARGS; \
- } \
- else { \
- METHOD<false, false, false, ALIGNMENT>ARGS; \
- } \
- } \
- }
+#ifdef EIGEN_USE_THREADS
+ template <typename EvalSubExprsCallback>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
+ EvaluatorPointerType dest, EvalSubExprsCallback done) {
+ m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
+ m_rightImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
+ if (dest) {
+ evalToAsync(dest, [done]() { done(false); });
+ } else {
+ m_result = static_cast<EvaluatorPointerType>(
+ m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
+ evalToAsync(m_result, [done]() { done(true); });
+ }
+ });
+ });
+ }
+#endif // EIGEN_USE_THREADS
+
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
+ if (this->m_lhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, true, true, ALIGNMENT> ARGS; \
+ } else { \
+ METHOD<true, true, false, ALIGNMENT> ARGS; \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ METHOD<true, false, true, ALIGNMENT> ARGS; \
+ } else { \
+ METHOD<true, false, false, ALIGNMENT> ARGS; \
+ } \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ METHOD<false, true, true, ALIGNMENT> ARGS; \
+ } else { \
+ METHOD<false, true, false, ALIGNMENT> ARGS; \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ METHOD<false, false, true, ALIGNMENT> ARGS; \
+ } else { \
+ METHOD<false, false, false, ALIGNMENT> ARGS; \
+ } \
+ } \
+ }
+
+#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
+ if (this->m_lhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
+ } else { \
+ (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
+ } else { \
+ (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
+ } \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_contiguous) { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
+ } else { \
+ (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
+ } \
+ } else { \
+ if (this->m_rhs_inner_dim_reordered) { \
+ (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
+ } else { \
+ (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
+ } \
+ } \
+ }
EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
}
+#ifdef EIGEN_USE_THREADS
+ template <typename EvalToCallback>
+ void evalToAsync(Scalar* buffer, EvalToCallback done) const {
+ static_cast<const Derived*>(this)
+ ->template evalProductAsync<EvalToCallback, Unaligned>(buffer,
+ std::move(done));
+ }
+#endif // EIGEN_USE_THREADS
+
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
void evalProductSequential(Scalar* buffer) const {