aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-09-10 11:37:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 11:46:08 -0700
commit96b77a647b1391d43cae869306628b479a22daa4 (patch)
tree0ac48e6b84f3d1cd272d5a28e4a73543b709446f /tensorflow
parenta8b2dd9f72fe78cca59d525230f5358430fec45c (diff)
[TF:XLA] Migrate unit tests to use the HLO verifier (only tests where the conversion is mostly automated).
PiperOrigin-RevId: 212303594
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/BUILD12
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/call_graph_test.cc26
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/shape_partition_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier_test.cc20
21 files changed, 130 insertions, 102 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 6ace6d3271..1965ba1204 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -87,6 +87,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -123,6 +124,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
@@ -352,6 +354,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -402,6 +405,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -498,6 +502,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -568,6 +573,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1131,6 +1137,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -1709,6 +1716,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2237,6 +2245,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -2315,6 +2324,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/core:test",
],
)
@@ -2428,6 +2438,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -2888,6 +2899,7 @@ tf_cc_test(
deps = [
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21c3b..5f93740887 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
protected:
+ BFloat16ConversionFoldingTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConversions(module.get()));
+ EXPECT_FALSE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(FoldConversions(module.get()));
+ EXPECT_TRUE(FoldConversions(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_EQ(tuple->operand(0), gte_a);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 933cf873e0..cef0eba14e 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
protected:
+ BFloat16NormalizationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(Normalize(module.get()));
+ EXPECT_FALSE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -317,7 +321,7 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(Normalize(module.get()));
+ EXPECT_TRUE(Normalize(module));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74843..34f3f914d5 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@ namespace {
using ::testing::UnorderedElementsAre;
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(1, call_graph->nodes().size());
EXPECT_TRUE(call_graph->IsFlattened());
@@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) {
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeMappingComputation(map_computation, /*callsites=*/5));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) {
HloComputation* entry_computation = module->AddEntryComputation(
MakeCallingComputation(called_computation, /*callsites=*/3));
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
// The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(2, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(3, call_graph->nodes().size());
@@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
EXPECT_FALSE(call_graph->IsFlattened());
@@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) {
entry_computation = module->AddEntryComputation(builder.Build());
}
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(5, call_graph->nodes().size());
// Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) {
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
std::vector<HloComputation*> visited;
TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) {
module->AddEntryComputation(MakeScalarComputation());
HloComputation* unreachable_computation =
module->AddEmbeddedComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Test visitation of only reachable nodes.
{
@@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) {
// Test that the call graph visitor properly propagates errors.
auto module = CreateNewModule();
module->AddEntryComputation(MakeScalarComputation());
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
Status status = call_graph->VisitNodes(
[](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 039cbbff6c..8cc522a59e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -801,6 +801,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -822,6 +823,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -946,6 +948,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -971,6 +974,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 05792795a1..2083f440fd 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@ namespace cpu {
using ::testing::ElementsAre;
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
public:
ConvCanonicalizationTest() {
for (int i = 0; i < 2; ++i) {
@@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
ConvCanonicalization conv_canonicalization(&target_machine_features);
- EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa446e..c9fb34be1c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) {
return count;
}
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
protected:
void InsertCopies(HloModule* module) {
CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*module), 3);
@@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
module->AddEntryComputation(builder.Build());
- InsertCopies(module.get());
+ InsertCopies(module);
EXPECT_EQ(CountCopies(*subcomputation), 2);
EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6de6..be1208fb2d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
CpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index 7d8e51f909..1a3d82de95 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace cpu {
namespace {
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
protected:
typedef std::vector<int64> Vec;
@@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
expected_partitions);
}
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
};
@@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) {
}
}
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
protected:
typedef std::vector<std::pair<int64, int64>> Partition;
RandomShapePartitionIteratorTest()
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index f11aff0573..c55206eee7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,6 +48,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 22721051e5..6bf3810967 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@ namespace xla {
namespace cpu {
namespace {
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
protected:
CpuFusionTest() {}
@@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -75,7 +75,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
@@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -122,7 +122,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
@@ -184,7 +184,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The computation root instruction was fused. Verify the fusion instruction
// is now the root.
@@ -209,7 +209,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
<< fusion_instruction2->fused_instructions_computation()->ToString();
// Compile and execute the computation.
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
// Check the output correctness.
LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
@@ -256,7 +256,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// Run fusion.
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
auto fusion1 = result->operand(0);
auto fusion2 = result->operand(1);
@@ -315,7 +315,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
module->AddEntryComputation(builder.Build());
CpuInstructionFusion fusion;
- EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(fusion.Run(module).ValueOrDie());
// The only fusion instruction should be operand 0 of the tuple (formerly
// negate1).
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 569381f5b0..af953a2a16 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -832,6 +833,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -901,6 +903,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 59ade96f7d..b857fa775a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -24,14 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
protected:
using HloVec = std::vector<const HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089df4c..27a4d0b601 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,7 +25,7 @@ namespace {
using ::testing::HasSubstr;
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
protected:
GpuHloSupportChecker& checker() { return checker_; }
@@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK(checker().Run(module.get()).status());
+ TF_ASSERT_OK(checker().Run(module).status());
}
TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Status status = checker().Run(module.get()).status();
+ Status status = checker().Run(module).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 8f0dedfa40..c4f43cc9a6 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -21,14 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
protected:
std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 00a25db467..957c4a6891 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
auto module = CreateNewModule();
@@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
- HloSchedule schedule(module.get());
+ HloSchedule schedule(module);
schedule.set_sequence(cond_computation,
{cond_param, cond_iter, cond_data, cond_lt});
schedule.set_sequence(body_computation, {body_param});
@@ -233,7 +233,7 @@ class HeapSimulatorTracker {
HeapSimulator::Result result_;
};
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
protected:
HeapSimulatorTest() {}
~HeapSimulatorTest() override {}
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c95972b..d9848cee0b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
namespace xla {
namespace {
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
TEST_F(HloReachabilityTest, Reachability) {
// Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 4b611fe450..f7e82fb1f8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
protected:
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
@@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// with rematerialization so pick a memory limit between these values (14KB).
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024, module.get()));
+ /*memory_limit_bytes=*/14 * 1024, module));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024, module.get()));
+ /*memory_limit_bytes=*/20 * 1024, module));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// bit lower (17KB) to force rematerialization of the entry computation.
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024, module.get()));
+ /*memory_limit_bytes=*/17 * 1024, module));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024, module.get()));
+ /*memory_limit_bytes=*/15 * 1024, module));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// ~12K so pick something slightly larger.
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024, module.get()));
+ /*memory_limit_bytes=*/13 * 1024, module));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024, module.get()));
+ /*memory_limit_bytes=*/22 * 1024, module));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(bool changed,
RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024, module.get()));
+ /*memory_limit_bytes=*/22 * 1024, module));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a1f2..6fd734a2b9 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -24,7 +24,7 @@ namespace {
using ::tensorflow::GraphDef;
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
protected:
HloTfGraphBuilderTest() {}
HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b693872d..516754e211 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
protected:
void Run(HloModule* module, bool change_expected) {
TupleSimplifier simplifier;
@@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
}
TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) {
EXPECT_THAT(computation->root_instruction(), gte);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param1);
}
@@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) {
EXPECT_THAT(computation->root_instruction(),
op::Negate(op::GetTupleElement(op::Tuple())));
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
@@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
EXPECT_THAT(computation->root_instruction(), element);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/true);
+ Run(module, /*change_expected=*/true);
EXPECT_THAT(computation->root_instruction(), tuple_param);
}
@@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
- Run(module.get(), /*change_expected=*/false);
+ Run(module, /*change_expected=*/false);
EXPECT_THAT(computation->root_instruction(), tuple);
}
@@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
entry = module->AddEntryComputation(builder.Build());
}
- Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+ Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
EXPECT_THAT(c0->root_instruction(), p0);
EXPECT_THAT(c1->root_instruction(), p1);