diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index c58af04bad..896f6ea842 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { @@ -32,7 +31,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = MakeUnique<HloModule>(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -49,7 +48,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = MakeUnique<HloModule>(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -89,7 +88,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - auto module = MakeUnique<HloModule>(TestName()); + auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -108,7 +107,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) { HloInstruction::CreateGetTupleElement(data_shape, param, 1)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); - auto module = MakeUnique<HloModule>(TestName()); + auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -124,3 +123,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) { } // namespace gpu } // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} |