aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index c58af04bad..896f6ea842 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
namespace gpu {
@@ -32,7 +31,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
@@ -49,7 +48,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
@@ -89,7 +88,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
@@ -108,7 +107,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
HloInstruction::CreateGetTupleElement(data_shape, param, 1));
builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1));
- auto module = MakeUnique<HloModule>(TestName());
+ auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
@@ -124,3 +123,7 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
} // namespace gpu
} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}