diff options
Diffstat (limited to 'unsupported/test/cxx11_tensor_scan.cpp')
-rw-r--r-- | unsupported/test/cxx11_tensor_scan.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/unsupported/test/cxx11_tensor_scan.cpp b/unsupported/test/cxx11_tensor_scan.cpp index dbd3023d7..bafa6c96e 100644 --- a/unsupported/test/cxx11_tensor_scan.cpp +++ b/unsupported/test/cxx11_tensor_scan.cpp @@ -39,6 +39,30 @@ static void test_1d_scan() } template <int DataLayout, typename Type=float> +static void test_1d_inclusive_scan() +{ + int size = 50; + Tensor<Type, 1, DataLayout> tensor(size); + tensor.setRandom(); + Tensor<Type, 1, DataLayout> result = tensor.cumsum(0, true); + + VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0)); + + float accum = 0; + for (int i = 0; i < size; i++) { + VERIFY_IS_EQUAL(result(i), accum); + accum += tensor(i); + } + + accum = 1; + result = tensor.cumprod(0, true); + for (int i = 0; i < size; i++) { + VERIFY_IS_EQUAL(result(i), accum); + accum *= tensor(i); + } +} + +template <int DataLayout, typename Type=float> static void test_4d_scan() { int size = 5; |