aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-09-19 13:54:49 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2019-09-19 13:54:49 -0700
commit1d5af0693c4d54cf15aa9a787d5765ddfaf706dc (patch)
tree6cb85b67b43aa62520e9b93cabb70cc19ba92f64
parent28b6786498cb7ad183744f4ac4b3734256d35125 (diff)
Add support for asynchronous evaluation of tensor casting expressions.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h31
-rw-r--r--unsupported/test/cxx11_tensor_thread_pool.cpp12
2 files changed, 37 insertions, 6 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
index e96f31537..fa329bfe6 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
@@ -178,6 +178,27 @@ template <typename Eval, typename EvalPointerType> struct ConversionSubExprEval<
}
};
+#ifdef EIGEN_USE_THREADS
+template <bool SameType, typename Eval, typename EvalPointerType,
+ typename EvalSubExprsCallback>
+struct ConversionSubExprEvalAsync {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
+ Eval& impl, EvalPointerType, EvalSubExprsCallback done) {
+ impl.evalSubExprsIfNeededAsync(nullptr, std::move(done));
+ }
+};
+
+template <typename Eval, typename EvalPointerType,
+ typename EvalSubExprsCallback>
+struct ConversionSubExprEvalAsync<true, Eval, EvalPointerType,
+ EvalSubExprsCallback> {
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
+ Eval& impl, EvalPointerType data, EvalSubExprsCallback done) {
+ impl.evalSubExprsIfNeededAsync(data, std::move(done));
+ }
+};
+#endif
+
namespace internal {
template <typename SrcType, typename TargetType, bool IsSameT>
@@ -299,6 +320,16 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
return ConversionSubExprEval<IsSameType, TensorEvaluator<ArgType, Device>, EvaluatorPointerType>::run(m_impl, data);
}
+#ifdef EIGEN_USE_THREADS
+ template <typename EvalSubExprsCallback>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
+ EvaluatorPointerType data, EvalSubExprsCallback done) {
+ ConversionSubExprEvalAsync<IsSameType, TensorEvaluator<ArgType, Device>,
+ EvaluatorPointerType,
+ EvalSubExprsCallback>::run(m_impl, data, std::move(done));
+ }
+#endif
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup()
{
m_impl.cleanup();
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index dae7b0335..b772a1d60 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -40,19 +40,19 @@ void test_multithread_elementwise()
{
Tensor<float, 3> in1(200, 30, 70);
Tensor<float, 3> in2(200, 30, 70);
- Tensor<float, 3> out(200, 30, 70);
+ Tensor<double, 3> out(200, 30, 70);
in1.setRandom();
in2.setRandom();
Eigen::ThreadPool tp(internal::random<int>(3, 11));
Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
- out.device(thread_pool_device) = in1 + in2 * 3.14f;
+ out.device(thread_pool_device) = (in1 + in2 * 3.14f).cast<double>();
for (int i = 0; i < 200; ++i) {
for (int j = 0; j < 30; ++j) {
for (int k = 0; k < 70; ++k) {
- VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
+ VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
}
}
}
@@ -62,7 +62,7 @@ void test_async_multithread_elementwise()
{
Tensor<float, 3> in1(200, 30, 70);
Tensor<float, 3> in2(200, 30, 70);
- Tensor<float, 3> out(200, 30, 70);
+ Tensor<double, 3> out(200, 30, 70);
in1.setRandom();
in2.setRandom();
@@ -71,13 +71,13 @@ void test_async_multithread_elementwise()
Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
Eigen::Barrier b(1);
- out.device(thread_pool_device, [&b]() { b.Notify(); }) = in1 + in2 * 3.14f;
+ out.device(thread_pool_device, [&b]() { b.Notify(); }) = (in1 + in2 * 3.14f).cast<double>();
b.Wait();
for (int i = 0; i < 200; ++i) {
for (int j = 0; j < 30; ++j) {
for (int k = 0; k < 70; ++k) {
- VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
+ VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
}
}
}