diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_constant_folding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 37 |
1 files changed, 18 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 07cd1efc12..3e0def5d26 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral<F32>( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { root->literal().EachCell<NativeT>( [&](absl::Span<const int64> indices, NativeT value) { std::vector<int64> rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get<NativeT>(rindexes)); + matched = matched && (value == literal_clone.Get<NativeT>(rindexes)); }); EXPECT_TRUE(matched); } @@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(kConstantFoldReduce)); + ParseAndVerifyModule(kConstantFoldReduce); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_TRUE(result); - EXPECT_EQ(6, module->entry_computation() + EXPECT_EQ(6, module() + .entry_computation() ->root_instruction() ->literal() .GetFirstElement<int32>()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(kConstantFoldReduce)); - HloInstruction* add = module->computations().begin()->root_instruction(); + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_FALSE(result); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } } // namespace |