diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/select_and_scatter_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/select_and_scatter_test.cc | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc new file mode 100644 index 0000000000..fb1effc8c4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -0,0 +1,395 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests the select-and-scatter XLA operation. + +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SelectAndScatterTest : public ClientLibraryTestBase { + public: + SelectAndScatterTest() : builder_(client_, TestName()) { + // Create S32 GE and ADD computations for select and scatter respectively. + ge_s32_ = CreateScalarGeComputation(S32, &builder_); + add_s32_ = CreateScalarAddComputation(S32, &builder_); + ge_f32_ = CreateScalarGeComputation(F32, &builder_); + add_f32_ = CreateScalarAddComputation(F32, &builder_); + max_f32_ = CreateScalarMaxComputation(F32, &builder_); + min_f32_ = CreateScalarMinComputation(F32, &builder_); + } + + ComputationBuilder builder_; + Computation ge_s32_; + Computation add_s32_; + Computation ge_f32_; + Computation add_f32_; + Computation max_f32_; + Computation min_f32_; +}; + +// Test for F32 1D array, with a zero-element input. +XLA_TEST_F(SelectAndScatterTest, R1S0F32) { + const auto operand = builder_.ConstantR1<float>({}); + const auto source = builder_.ConstantR1<float>({}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7)); +} + +// Test for F32 1D array, when windows do not overlap. +XLA_TEST_F(SelectAndScatterTest, R1F32) { + const auto operand = + builder_.ConstantR1<float>({1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); + const auto source = builder_.ConstantR1<float>({34.f, 42.f}); + const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +// Test for S32 1D array, when windows do not overlap and the init value is 1. +XLA_TEST_F(SelectAndScatterTest, R1S32) { + const auto operand = builder_.ConstantR1<int32>({-1, 0, 6, 4, -4, 10}); + const auto source = builder_.ConstantR1<int32>({-10, 20}); + const std::vector<int32> expected = {1, 1, -9, 1, 1, 21}; + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + builder_.ConstantR0<int32>(1), add_s32_); + ComputeAndCompareR1<int32>(&builder_, expected, {}); +} + +// Test for S32 1D array, when windows overlap with each other. +XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { + const auto operand = builder_.ConstantR1<int32>({1, 9, 3, 7, 5, 6}); + const auto source = builder_.ConstantR1<int32>({34, 42, 53, 19}); + const std::vector<int32> expected = {0, 76, 0, 72, 0, 0}; + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR1<int32>(&builder_, expected, {}); +} + +// Test for S32 2D array, when windows do not overlap. +XLA_TEST_F(SelectAndScatterTest, R2S32) { + const auto operand = + builder_.ConstantR2<int32>({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); + const auto source = builder_.ConstantR2<int32>({{2, 6}}); + Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR2<int32>(&builder_, expected, {}); +} + +// Similar to SelectAndScatterTest.R2S32 but the input is transposed. +XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { + const auto operand = builder_.ConstantR2<int32>( + {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); + const auto reshape = + builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); + const auto source = builder_.ConstantR2<int32>({{2, 6}}); + Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); + builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR2<int32>(&builder_, expected, {}); +} + +// Test for S32 2D array, when windows overlap with each other. +XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) { + const auto operand = + builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = builder_.ConstantR2<int32>({{2, 6, 4}}); + Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR2<int32>(&builder_, expected, {}); +} + +// Test for S32 2D array, when the padding is Padding::kSAME. +XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { + const auto operand = + builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = builder_.ConstantR2<int32>({{2, 6, 4}}); + Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{2, 2}, Padding::kSame, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR2<int32>(&builder_, expected, {}); +} + +// Test for S32 2D array, when the padding is Padding::kSAME and windows overlap +// with each other. +XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) { + const auto operand = + builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = + builder_.ConstantR2<int32>({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); + Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}}); + builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + builder_.ConstantR0<int32>(0), add_s32_); + ComputeAndCompareR2<int32>(&builder_, expected, {}); +} + +XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) { + const auto operand = builder_.ConstantR2<float>( + {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); + const auto source = builder_.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D<float> expected( + {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32Valid) { + Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, + {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}}; + Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f}, + {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}}; + Array4D<float> o(4, 6, 15, 220); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D<float> e(4, 6, 15, 220); + e.FillWithPZ(pze); + Array4D<float> s(2, 2, 15, 220); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32Overlap) { + Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.0f}, + {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}}; + Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 8.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; + Array4D<float> o(4, 5, 17, 128); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D<float> e(4, 5, 17, 128); + e.FillWithPZ(pze); + Array4D<float> s(2, 2, 17, 128); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32OverlapSmall) { + Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.0f}, + {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}}; + Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 8.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 3.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; + Array4D<float> o(4, 5, 1, 1); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D<float> e(4, 5, 1, 1); + e.FillWithPZ(pze); + Array4D<float> s(2, 2, 1, 1); + s.FillWithPZ(pzs); + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { + // This test is testing the Reference Util + Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f}, + {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f}, + {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f}, + {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}}; + Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; + Array4D<float> o(4, 6, 4, 4); + o.FillWithPZ(pzo); + auto operand = builder_.ConstantR4FromArray4D(o); + Array4D<float> s(2, 2, 4, 4); + s.FillWithPZ(pzs); + + auto source = builder_.ConstantR4FromArray4D(s); + s.FillWithPZ(pzs); + builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1}, + {2, 3, 1, 1}, false); + ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefSameRandom) { + Array4D<float> o(7, 7, 8, 256); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D<float> s(4, 4, 8, 256); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1}, + Padding::kSame, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 2, 1, 1}, + {2, 2, 1, 1}, true); + + ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefSameRandomFullyPadded) { + Array4D<float> o(1, 1, 5, 5); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D<float> s(1, 1, 5, 5); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kSame, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, true); + + ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidRandom) { + Array4D<float> o(9, 9, 16, 128); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D<float> s(3, 3, 16, 128); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, false); + + ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +TEST_F(SelectAndScatterTest, R4F32RefValidRandomSmall) { + Array4D<float> o(3, 3, 4, 4); + o.FillRandom(1.5f); + auto operand = builder_.ConstantR4FromArray4D(o); + + Array4D<float> s(1, 1, 4, 4); + s.FillRandom(12.0f); + auto source = builder_.ConstantR4FromArray4D(s); + + builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1}, + Padding::kValid, source, + builder_.ConstantR0<float>(0.0f), add_f32_); + + auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1}, + {3, 3, 1, 1}, false); + + ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7)); +} + +XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { + const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1}); + const auto source = builder_.ConstantR1<float>({34, 42, 53, 19}); + const std::vector<float> expected = {0, 0, 0, 53, 0, 0, 0}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0<float>(0), max_f32_); + ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { + const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1}); + const auto source = builder_.ConstantR1<float>({34, 42, 53, 19}); + const float max_float = std::numeric_limits<float>::max(); + const std::vector<float> expected = {max_float, max_float, max_float, 19, + max_float, max_float, max_float}; + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + builder_.ConstantR0<float>(max_float), min_f32_); + ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7)); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} |