diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-10 10:56:58 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2019-10-10 10:56:58 -0700 |
commit | a411e9f344a354673b72a490433cf3cc2148ddf1 (patch) | |
tree | 65d0e152a0cc6649ecb8b67c0579386475dbaf53 /unsupported/test | |
parent | b03eb63d7cb869cc4486ac393fad75fbcc36027f (diff) |
Block evaluation for TensorGenerator + TensorReverse + fixed bug in tensor reverse op
Diffstat (limited to 'unsupported/test')
-rw-r--r-- | unsupported/test/cxx11_tensor_block_eval.cpp | 44 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_executor.cpp | 20 |
2 files changed, 54 insertions, 10 deletions
diff --git a/unsupported/test/cxx11_tensor_block_eval.cpp b/unsupported/test/cxx11_tensor_block_eval.cpp index e11092af3..aac75014c 100644 --- a/unsupported/test/cxx11_tensor_block_eval.cpp +++ b/unsupported/test/cxx11_tensor_block_eval.cpp @@ -369,6 +369,48 @@ static void test_eval_tensor_chipping() { [&chipped_dims]() { return RandomBlock<Layout>(chipped_dims, 1, 10); }); } +template <typename T, int NumDims, int Layout> +static void test_eval_tensor_generator() { + DSizes<Index, NumDims> dims = RandomDims<NumDims>(10, 20); + Tensor<T, NumDims, Layout> input(dims); + input.setRandom(); + + auto generator = [](const array<Index, NumDims>& dims) -> T { + T result = static_cast<T>(0); + for (int i = 0; i < NumDims; ++i) { + result += static_cast<T>((i + 1) * dims[i]); + } + return result; + }; + + VerifyBlockEvaluator<T, NumDims, Layout>( + input.generate(generator), + [&dims]() { return FixedSizeBlock(dims); }); + + VerifyBlockEvaluator<T, NumDims, Layout>( + input.generate(generator), + [&dims]() { return RandomBlock<Layout>(dims, 1, 10); }); +} + +template <typename T, int NumDims, int Layout> +static void test_eval_tensor_reverse() { + DSizes<Index, NumDims> dims = RandomDims<NumDims>(10, 20); + Tensor<T, NumDims, Layout> input(dims); + input.setRandom(); + + // Randomly reverse dimensions. + Eigen::DSizes<bool, NumDims> reverse; + for (int i = 0; i < NumDims; ++i) reverse[i] = internal::random<bool>(); + + VerifyBlockEvaluator<T, NumDims, Layout>( + input.reverse(reverse), + [&dims]() { return FixedSizeBlock(dims); }); + + VerifyBlockEvaluator<T, NumDims, Layout>( + input.reverse(reverse), + [&dims]() { return RandomBlock<Layout>(dims, 1, 10); }); +} + template <typename T, int Layout> static void test_eval_tensor_reshape_with_bcast() { Index dim = internal::random<Index>(1, 100); @@ -573,6 +615,8 @@ EIGEN_DECLARE_TEST(cxx11_tensor_block_eval) { CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_select); CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_padding); CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_chipping); + CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_generator); + CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_reverse); CALL_SUBTESTS_LAYOUTS(test_eval_tensor_reshape_with_bcast); CALL_SUBTESTS_LAYOUTS(test_eval_tensor_forced_eval); diff --git a/unsupported/test/cxx11_tensor_executor.cpp b/unsupported/test/cxx11_tensor_executor.cpp index 8fb4ba752..66f932746 100644 --- a/unsupported/test/cxx11_tensor_executor.cpp +++ b/unsupported/test/cxx11_tensor_executor.cpp @@ -539,7 +539,7 @@ static void test_execute_reverse_rvalue(Device d) // Reverse half of the dimensions. Eigen::array<bool, NumDims> reverse; - for (int i = 0; i < NumDims; ++i) reverse[i] = (dims[i] % 2 == 0); + for (int i = 0; i < NumDims; ++i) reverse[i] = internal::random<bool>(); const auto expr = src.reverse(reverse); @@ -756,16 +756,16 @@ EIGEN_DECLARE_TEST(cxx11_tensor_executor) { CALL_SUBTEST_COMBINATIONS_V2(12, test_execute_broadcasting_of_forced_eval, float, 4); CALL_SUBTEST_COMBINATIONS_V2(12, test_execute_broadcasting_of_forced_eval, float, 5); - CALL_SUBTEST_COMBINATIONS_V1(13, test_execute_generator_op, float, 2); - CALL_SUBTEST_COMBINATIONS_V1(13, test_execute_generator_op, float, 3); - CALL_SUBTEST_COMBINATIONS_V1(13, test_execute_generator_op, float, 4); - CALL_SUBTEST_COMBINATIONS_V1(13, test_execute_generator_op, float, 5); + CALL_SUBTEST_COMBINATIONS_V2(13, test_execute_generator_op, float, 2); + CALL_SUBTEST_COMBINATIONS_V2(13, test_execute_generator_op, float, 3); + CALL_SUBTEST_COMBINATIONS_V2(13, test_execute_generator_op, float, 4); + CALL_SUBTEST_COMBINATIONS_V2(13, test_execute_generator_op, float, 5); - CALL_SUBTEST_COMBINATIONS_V1(14, test_execute_reverse_rvalue, float, 1); - CALL_SUBTEST_COMBINATIONS_V1(14, test_execute_reverse_rvalue, float, 2); - CALL_SUBTEST_COMBINATIONS_V1(14, test_execute_reverse_rvalue, float, 3); - CALL_SUBTEST_COMBINATIONS_V1(14, test_execute_reverse_rvalue, float, 4); - CALL_SUBTEST_COMBINATIONS_V1(14, test_execute_reverse_rvalue, float, 5); + CALL_SUBTEST_COMBINATIONS_V2(14, test_execute_reverse_rvalue, float, 1); + CALL_SUBTEST_COMBINATIONS_V2(14, test_execute_reverse_rvalue, float, 2); + CALL_SUBTEST_COMBINATIONS_V2(14, test_execute_reverse_rvalue, float, 3); + CALL_SUBTEST_COMBINATIONS_V2(14, test_execute_reverse_rvalue, float, 4); + CALL_SUBTEST_COMBINATIONS_V2(14, test_execute_reverse_rvalue, float, 5); CALL_ASYNC_SUBTEST_COMBINATIONS(15, test_async_execute_unary_expr, float, 3); CALL_ASYNC_SUBTEST_COMBINATIONS(15, test_async_execute_unary_expr, float, 4); |