From 6293ad3f392a7b97ebb9f9f874682505c1391f2d Mon Sep 17 00:00:00 2001
From: Vamsi Sripathi
Date: Wed, 23 May 2018 14:02:05 -0700
Subject: Performance improvements to tensor broadcast operation

1. Added new packet functions using SIMD for NByOne, OneByN cases
2. Modified existing packet functions to reduce index calculations when input stride is non-SIMD
3. Added 4 test cases to cover the new packet functions
---
 unsupported/test/cxx11_tensor_broadcasting.cpp | 62 ++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

(limited to 'unsupported/test/cxx11_tensor_broadcasting.cpp')

Reviewer notes (free-form text in this position is ignored when the patch is applied):

  * test_simple_broadcasting_one_by_n builds a random (1,13,5,7) tensor,
    broadcasts it 9x along dimension 0 only, checks the resulting dimensions
    are (9,13,5,7), and compares every element against the source tensor
    indexed modulo its own extents.
  * test_simple_broadcasting_n_by_one does the same for a (7,3,5,1) tensor
    broadcast 19x along dimension 3 only, expecting dimensions (7,3,5,19).
  * Each new test is registered twice in test_cxx11_tensor_broadcasting().
  * NOTE(review): the template parameter/argument lists were missing from the
    copy this patch was recovered from. They are restored below as
    <int DataLayout>, Tensor<float, 4, DataLayout>, array<ptrdiff_t, 4> and
    <RowMajor>/<ColMajor>. Rank 4 follows from the four-index constructor and
    accessors; the scalar type, index type and layout names are assumed from
    the surrounding test file - confirm against upstream commit 6293ad3f.
  * NOTE(review): indentation and blank context lines were also reconstructed;
    the hunk line counts (6 context + 58 added = 64, and 191 + 58 = 249) are
    consistent with this layout.

diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp
index 5c0ea5889..a9d268ea6 100644
--- a/unsupported/test/cxx11_tensor_broadcasting.cpp
+++ b/unsupported/test/cxx11_tensor_broadcasting.cpp
@@ -180,6 +180,64 @@ static void test_fixed_size_broadcasting()
 #endif
 }
 
+template <int DataLayout>
+static void test_simple_broadcasting_one_by_n()
+{
+  Tensor<float, 4, DataLayout> tensor(1,13,5,7);
+  tensor.setRandom();
+  array<ptrdiff_t, 4> broadcasts;
+  broadcasts[0] = 9;
+  broadcasts[1] = 1;
+  broadcasts[2] = 1;
+  broadcasts[3] = 1;
+  Tensor<float, 4, DataLayout> broadcast;
+  broadcast = tensor.broadcast(broadcasts);
+
+  VERIFY_IS_EQUAL(broadcast.dimension(0), 9);
+  VERIFY_IS_EQUAL(broadcast.dimension(1), 13);
+  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
+  VERIFY_IS_EQUAL(broadcast.dimension(3), 7);
+
+  for (int i = 0; i < 9; ++i) {
+    for (int j = 0; j < 13; ++j) {
+      for (int k = 0; k < 5; ++k) {
+        for (int l = 0; l < 7; ++l) {
+          VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l));
+        }
+      }
+    }
+  }
+}
+
+template <int DataLayout>
+static void test_simple_broadcasting_n_by_one()
+{
+  Tensor<float, 4, DataLayout> tensor(7,3,5,1);
+  tensor.setRandom();
+  array<ptrdiff_t, 4> broadcasts;
+  broadcasts[0] = 1;
+  broadcasts[1] = 1;
+  broadcasts[2] = 1;
+  broadcasts[3] = 19;
+  Tensor<float, 4, DataLayout> broadcast;
+  broadcast = tensor.broadcast(broadcasts);
+
+  VERIFY_IS_EQUAL(broadcast.dimension(0), 7);
+  VERIFY_IS_EQUAL(broadcast.dimension(1), 3);
+  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
+  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
+
+  for (int i = 0; i < 7; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      for (int k = 0; k < 5; ++k) {
+        for (int l = 0; l < 19; ++l) {
+          VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l));
+        }
+      }
+    }
+  }
+}
+
 
 void test_cxx11_tensor_broadcasting()
 {
@@ -191,4 +249,8 @@ void test_cxx11_tensor_broadcasting()
   CALL_SUBTEST(test_static_broadcasting<RowMajor>());
   CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
   CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
+  CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>());
+  CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
+  CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
+  CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
 }
-- 
cgit v1.2.3