/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/test_utils.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

// Tests for MakeFakeArguments().
//
// Note on the HLO text used below: each module is written as a sequence of
// adjacent string literals, one HLO construct per source line. The pieces
// concatenate to a single string with single-space separators; the split is
// for readability of this file only.

// A test fixture is used because we need a client for our computation builder.
class TestUtilsTest : public LocalClientTestBase {};

// Compiles a reduction whose reducer computation declares two scalar
// parameters (named "unused" and "used") and checks that MakeFakeArguments
// succeeds on the compiled module.
XLA_TEST_F(TestUtilsTest, UnusedParam) {
  XlaBuilder builder(TestName());
  // Make the reduction lambda.
  Shape single_float = ShapeUtil::MakeShape(F32, {});
  Parameter(&builder, 0, single_float, "unused");
  Parameter(&builder, 1, single_float, "used");
  auto computation_status = builder.Build();
  TF_ASSERT_OK(computation_status.status());

  // Make the reduction.
  Shape pair_float = ShapeUtil::MakeShape(F32, {2});
  Reduce(Parameter(&builder, 0, pair_float, "operand"),
         Parameter(&builder, 1, single_float, "init"),
         computation_status.ValueOrDie(), {0});
  computation_status = builder.Build();
  TF_ASSERT_OK(computation_status.status());

  auto executable_status = local_client_->Compile(
      computation_status.ValueOrDie(), {&pair_float, &single_float},
      ExecutableBuildOptions());
  TF_ASSERT_OK(executable_status.status());
  // MakeFakeArguments is handed a non-const HloModule*, so the module obtained
  // from the executable is cast to a mutable reference first.
  HloModule& module = const_cast<HloModule&>(
      executable_status.ValueOrDie()->executable()->module());
  TF_ASSERT_OK(MakeFakeArguments(&module).status());
}

// Checks that MakeFakeArguments succeeds on a module whose entry computation
// takes a token[] parameter.
XLA_TEST_F(TestUtilsTest, Token) {
  static constexpr char kHloText[] =
      "HloModule outfeed_module "
      "ENTRY InfeedToOutfeed { "
      "token = token[] parameter(0) "
      "infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) "
      "infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 "
      "outfeed = token[] outfeed(infeed.data, token) "
      "ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) "
      "infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), "
      "index=0 "
      "infeed.1.token = token[] get-tuple-element(infeed.1), index=1 "
      "outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) "
      "}";
  auto module = ParseHloString(kHloText).ValueOrDie();
  TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}

// One index parameter feeds two dynamic-slices over differently shaped
// operands. The expected bounds are the per-dimension minimum of
// (operand size - slice size) across both slices:
//   dim 0: min(123 - 1, 3 - 3)    = 0
//   dim 1: min(4 - 2, 3000 - 2)   = 2
//   dim 2: min(789 - 3, 5 - 2)    = 3
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
  static constexpr char kHloText[] =
      "HloModule index_space_module "
      "ENTRY IndexSpace { "
      "index_param = s32[3]{0} parameter(0) "
      "array_param.1 = f32[123,4,789]{0,1,2} parameter(1) "
      "array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) "
      "dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, "
      "index_param), dynamic_slice_sizes={1,2,3} "
      "ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, "
      "index_param), dynamic_slice_sizes={3,2,2} "
      "}";
  auto module = ParseHloString(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                          MakeFakeArguments(module.get()));
  ASSERT_EQ(args.size(), 3);
  const Literal& index_arg = args[0];

  EXPECT_EQ(index_arg.Get<int32>({0}), 0);

  EXPECT_GE(index_arg.Get<int32>({1}), 0);
  EXPECT_LE(index_arg.Get<int32>({1}), 2);

  EXPECT_GE(index_arg.Get<int32>({2}), 0);
  EXPECT_LE(index_arg.Get<int32>({2}), 3);
}

// Same scenario as above, but the shared index parameter feeds two
// dynamic-update-slices. The update shapes ([1,2,3] and [3,2,2]) match the
// slice sizes of the dynamic-slice test, so the expected bounds are the same:
// 0 for dim 0, [0, 2] for dim 1 and [0, 3] for dim 2.
XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
  static constexpr char kHloText[] =
      "HloModule index_space_module "
      "ENTRY IndexSpace { "
      "index_param = s32[3]{0} parameter(0) "
      "array_param.1 = f32[123,4,789]{0,1,2} parameter(1) "
      "array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) "
      "update_param.1 = f32[1,2,3]{0,1,2} parameter(3) "
      "update_param.2 = f32[3,2,2]{0,1,2} parameter(4) "
      "dynamic-update-slice.1 = f32[123,4,789] "
      "dynamic-update-slice(array_param.1, update_param.1, index_param) "
      "ROOT dynamic-update-slice.2 = f32[3,3000,5] "
      "dynamic-update-slice(array_param.2, update_param.2, index_param) "
      "}";
  auto module = ParseHloString(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                          MakeFakeArguments(module.get()));
  ASSERT_EQ(args.size(), 5);
  const Literal& index_arg = args[0];

  EXPECT_EQ(index_arg.Get<int32>({0}), 0);

  EXPECT_GE(index_arg.Get<int32>({1}), 0);
  EXPECT_LE(index_arg.Get<int32>({1}), 2);

  EXPECT_GE(index_arg.Get<int32>({2}), 0);
  EXPECT_LE(index_arg.Get<int32>({2}), 3);
}

XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) {
  // Inputs which are sort keys in key/value sorts should have no duplicates.
  static constexpr char kHloText[] =
      " HloModule sort.148.1589 "
      "ENTRY %sort.148.1589 "
      "(parameter.0: f32[1048576], parameter.1: s32[1048576]) "
      "-> (f32[1048576], s32[1048576]) { "
      "%parameter.0 = f32[1048576]{0} parameter(0) "
      "%parameter.1 = s32[1048576]{0} parameter(1) "
      "ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) "
      "sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), "
      "dimensions={0} "
      "} ";
  auto module = ParseHloString(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                          MakeFakeArguments(module.get()));
  ASSERT_EQ(args.size(), 2);
  const Literal& key_arg = args[0];

  // Keys are compared by their 32-bit representation; insert() reports
  // `second == false` when the same bit pattern has been seen before.
  absl::flat_hash_set<uint32> key_set;
  for (const float& value : key_arg.data<float>()) {
    EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
  }
}

XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) {
  // Inputs which are sort keys in key/value sorts should have no duplicates.
  static constexpr char kHloText[] =
      " HloModule sort.148.1589 "
      "ENTRY %sort.148.1589 "
      "(parameter.0: s32[1048576], parameter.1: s32[1048576]) "
      "-> (s32[1048576], s32[1048576]) { "
      "%parameter.0 = s32[1048576]{0} parameter(0) "
      "%parameter.1 = s32[1048576]{0} parameter(1) "
      "ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) "
      "sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), "
      "dimensions={0} "
      "} ";
  auto module = ParseHloString(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                          MakeFakeArguments(module.get()));
  ASSERT_EQ(args.size(), 2);
  const Literal& key_arg = args[0];

  // Same uniqueness check as the float variant, on the keys' 32-bit
  // representation.
  absl::flat_hash_set<uint32> key_set;
  for (const int32& value : key_arg.data<int32>()) {
    EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
  }
}

}  // namespace
}  // namespace xla