aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_constant_folding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc37
1 files changed, 18 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1efc12..3e0def5d26 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
- auto literal_clone = literal->Literal::CloneToUnique();
+ auto literal_clone = literal.Clone();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
root->literal().EachCell<NativeT>(
[&](absl::Span<const int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+ matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
});
EXPECT_TRUE(matched);
}
@@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
+ ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);
- EXPECT_EQ(6, module->entry_computation()
+ EXPECT_EQ(6, module()
+ .entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
- HloInstruction* add = module->computations().begin()->root_instruction();
+ ParseAndVerifyModule(kConstantFoldReduce);
+ HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);
- EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
} // namespace