diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1f00aa41dc..b589cd573d 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -47,7 +48,9 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { XLA_REGISTER_GRAPH_RENDERER(DotRenderer); -TEST(HloGraphDumperTest, NestedFusion) { +class HloGraphDumperTest : public HloTestBase {}; + +TEST_F(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); // Build param0 + param1 + param2 + param3 + param4. @@ -64,10 +67,9 @@ TEST(HloGraphDumperTest, NestedFusion) { sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, sums[i], params[i + 2]))); } - - HloModule m(TestName()); - m.AddEntryComputation(b.Build()); - HloComputation* root_computation = m.entry_computation(); + auto m = CreateNewModule(); + m->AddEntryComputation(b.Build()); + HloComputation* root_computation = m->entry_computation(); // Fuse into fusion(param0 + param1 + param2 + param3 + param4). auto* outer_fusion = root_computation->CreateFusionInstruction( @@ -117,13 +119,13 @@ TEST(HloGraphDumperTest, NestedFusion) { HasSubstr(inner_sum->name())); } -TEST(HloGraphDumperTest, Constant) { +TEST_F(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(-42))); instruction->set_name("i_am_a_constant_root_instruction"); - HloModule m(TestName()); - HloComputation* root_computation = m.AddEntryComputation(b.Build()); + auto m = CreateNewModule(); + HloComputation* root_computation = m->AddEntryComputation(b.Build()); string graph = hlo_graph_dumper::DumpGraph( *root_computation, /*label=*/"an_empty_graph", DebugOptions()); EXPECT_THAT(graph, HasSubstr("an_empty_graph")); |