aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc18
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 1f00aa41dc..b589cd573d 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -47,7 +48,9 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
-TEST(HloGraphDumperTest, NestedFusion) {
+class HloGraphDumperTest : public HloTestBase {};
+
+TEST_F(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
@@ -64,10 +67,9 @@ TEST(HloGraphDumperTest, NestedFusion) {
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
}
-
- HloModule m(TestName());
- m.AddEntryComputation(b.Build());
- HloComputation* root_computation = m.entry_computation();
+ auto m = CreateNewModule();
+ m->AddEntryComputation(b.Build());
+ HloComputation* root_computation = m->entry_computation();
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
auto* outer_fusion = root_computation->CreateFusionInstruction(
@@ -117,13 +119,13 @@ TEST(HloGraphDumperTest, NestedFusion) {
HasSubstr(inner_sum->name()));
}
-TEST(HloGraphDumperTest, Constant) {
+TEST_F(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
instruction->set_name("i_am_a_constant_root_instruction");
- HloModule m(TestName());
- HloComputation* root_computation = m.AddEntryComputation(b.Build());
+ auto m = CreateNewModule();
+ HloComputation* root_computation = m->AddEntryComputation(b.Build());
string graph = hlo_graph_dumper::DumpGraph(
*root_computation, /*label=*/"an_empty_graph", DebugOptions());
EXPECT_THAT(graph, HasSubstr("an_empty_graph"));