diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-09-19 13:54:49 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2019-09-19 13:54:49 -0700 |
commit | 1d5af0693c4d54cf15aa9a787d5765ddfaf706dc (patch) | |
tree | 6cb85b67b43aa62520e9b93cabb70cc19ba92f64 /unsupported/test/cxx11_tensor_thread_pool.cpp | |
parent | 28b6786498cb7ad183744f4ac4b3734256d35125 (diff) |
Add support for asynchronous evaluation of tensor casting expressions.
Diffstat (limited to 'unsupported/test/cxx11_tensor_thread_pool.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_thread_pool.cpp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index dae7b0335..b772a1d60 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -40,19 +40,19 @@ void test_multithread_elementwise() { Tensor<float, 3> in1(200, 30, 70); Tensor<float, 3> in2(200, 30, 70); - Tensor<float, 3> out(200, 30, 70); + Tensor<double, 3> out(200, 30, 70); in1.setRandom(); in2.setRandom(); Eigen::ThreadPool tp(internal::random<int>(3, 11)); Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); - out.device(thread_pool_device) = in1 + in2 * 3.14f; + out.device(thread_pool_device) = (in1 + in2 * 3.14f).cast<double>(); for (int i = 0; i < 200; ++i) { for (int j = 0; j < 30; ++j) { for (int k = 0; k < 70; ++k) { - VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f); + VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f)); } } } @@ -62,7 +62,7 @@ void test_async_multithread_elementwise() { Tensor<float, 3> in1(200, 30, 70); Tensor<float, 3> in2(200, 30, 70); - Tensor<float, 3> out(200, 30, 70); + Tensor<double, 3> out(200, 30, 70); in1.setRandom(); in2.setRandom(); @@ -71,13 +71,13 @@ void test_async_multithread_elementwise() Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11)); Eigen::Barrier b(1); - out.device(thread_pool_device, [&b]() { b.Notify(); }) = in1 + in2 * 3.14f; + out.device(thread_pool_device, [&b]() { b.Notify(); }) = (in1 + in2 * 3.14f).cast<double>(); b.Wait(); for (int i = 0; i < 200; ++i) { for (int j = 0; j < 30; ++j) { for (int k = 0; k < 70; ++k) { - VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f); + VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f)); } } } |