aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_executor.cpp
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-27 12:45:17 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-07-27 12:45:17 -0700
commit966c2a7bb62a8b5b9ecd349730ffcd3b5719837d (patch)
tree83e61bb77a5340f529c336afaa69cc78d654d599 /unsupported/test/cxx11_tensor_executor.cpp
parent6913221c43c6ad41b1fbfc0d263d2764abd11ad2 (diff)
Rename Index to StorageIndex + use Eigen::Array and Eigen::Map when possible
Diffstat (limited to 'unsupported/test/cxx11_tensor_executor.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_executor.cpp20
1 files changed, 13 insertions, 7 deletions
diff --git a/unsupported/test/cxx11_tensor_executor.cpp b/unsupported/test/cxx11_tensor_executor.cpp
index 5ae45ac5b..274f901ce 100644
--- a/unsupported/test/cxx11_tensor_executor.cpp
+++ b/unsupported/test/cxx11_tensor_executor.cpp
@@ -13,7 +13,6 @@
#include <Eigen/CXX11/Tensor>
-using Eigen::Index;
using Eigen::Tensor;
using Eigen::RowMajor;
using Eigen::ColMajor;
@@ -25,9 +24,16 @@ template <typename Device, bool Vectorizable, bool Tileable, int Layout>
static void test_execute_binary_expr(Device d) {
// Pick a large enough tensor size to bypass small tensor block evaluation
// optimization.
- Tensor<float, 3> lhs(840, 390, 37);
- Tensor<float, 3> rhs(840, 390, 37);
- Tensor<float, 3> dst(840, 390, 37);
+ int d0 = internal::random<int>(100, 200);
+ int d1 = internal::random<int>(100, 200);
+ int d2 = internal::random<int>(100, 200);
+
+ static constexpr int Options = 0;
+ using IndexType = int;
+
+ Tensor<float, 3, Options, IndexType> lhs(d0, d1, d2);
+ Tensor<float, 3, Options, IndexType> rhs(d0, d1, d2);
+ Tensor<float, 3, Options, IndexType> dst(d0, d1, d2);
lhs.setRandom();
rhs.setRandom();
@@ -40,9 +46,9 @@ static void test_execute_binary_expr(Device d) {
Executor::run(Assign(dst, expr), d);
- for (int i = 0; i < 840; ++i) {
- for (int j = 0; j < 390; ++j) {
- for (int k = 0; k < 37; ++k) {
+ for (int i = 0; i < d0; ++i) {
+ for (int j = 0; j < d1; ++j) {
+ for (int k = 0; k < d2; ++k) {
float sum = lhs(i, j, k) + rhs(i, j, k);
VERIFY_IS_EQUAL(sum, dst(i, j, k));
}