diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils_test.cc | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a2f0338e25..64d9e2031e 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get<int32>({0}), 0); + + EXPECT_GE(index_arg.Get<int32>({1}), 0); + EXPECT_LE(index_arg.Get<int32>({1}), 2); + + EXPECT_GE(index_arg.Get<int32>({2}), 0); + EXPECT_LE(index_arg.Get<int32>({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 5); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get<int32>({0}), 0); + + EXPECT_GE(index_arg.Get<int32>({1}), 0); + EXPECT_LE(index_arg.Get<int32>({1}), 2); + + EXPECT_GE(index_arg.Get<int32>({2}), 0); + EXPECT_LE(index_arg.Get<int32>({2}), 3); +} + } // namespace } // namespace xla |