From 6190aa5632fb698fa66d2dad2949275089f15738 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Mon, 9 Jul 2018 11:23:16 +0200 Subject: bug #1567: add optimized path for tensor broadcasting and 'Channel First' shape --- unsupported/test/cxx11_tensor_broadcasting.cpp | 57 ++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) (limited to 'unsupported/test/cxx11_tensor_broadcasting.cpp') diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp index a9d268ea6..f0ff03184 100644 --- a/unsupported/test/cxx11_tensor_broadcasting.cpp +++ b/unsupported/test/cxx11_tensor_broadcasting.cpp @@ -238,6 +238,59 @@ static void test_simple_broadcasting_n_by_one() } } +template +static void test_simple_broadcasting_one_by_n_by_one_1d() +{ + Tensor tensor(1,7,1); + tensor.setRandom(); + array broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 13; + Tensor broadcasted; + broadcasted = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcasted.dimension(0), 5); + VERIFY_IS_EQUAL(broadcasted.dimension(1), 7); + VERIFY_IS_EQUAL(broadcasted.dimension(2), 13); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k)); + } + } + } +} + +template +static void test_simple_broadcasting_one_by_n_by_one_2d() +{ + Tensor tensor(1,7,13,1); + tensor.setRandom(); + array broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 5); + VERIFY_IS_EQUAL(broadcast.dimension(1), 7); + VERIFY_IS_EQUAL(broadcast.dimension(2), 13); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l)); + } + } + } + } +} void test_cxx11_tensor_broadcasting() { @@ -253,4 +306,8 @@ void test_cxx11_tensor_broadcasting() CALL_SUBTEST(test_simple_broadcasting_n_by_one()); CALL_SUBTEST(test_simple_broadcasting_one_by_n()); CALL_SUBTEST(test_simple_broadcasting_n_by_one()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d()); } -- cgit v1.2.3