about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 38
1 files changed, 34 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 1963d9eef7..8d0522bd8f 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndOperandElementReusingConsumerNotFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* broadcast2 =
@@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest,
NonCostlyProducerAndOperandElementReusingConsumerFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
HloInstruction* broadcast2 =
@@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* reshape2 = builder.AddInstruction(
@@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* transpose2 = builder.AddInstruction(
@@ -606,5 +606,35 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) {
op::Parameter()));
}
+// Checks that, after GpuInstructionFusion runs on a sum of kNumParams parameters (more than the limit), no instruction in the entry computation has more than kMaxOperandsAndOutputsPerFusion operands.
+TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
+ constexpr int64 kNumParams = 200;  // chosen to exceed the fusion limit; the ASSERT_GT below fails the test if it does not
+ ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion);
+
+ // Build a left-leaning chain of adds: ((p0 + p1) + p2) + ... + p(kNumParams - 1).
+ HloComputation::Builder b(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 100});  // every parameter and every add uses this one shape
+ auto param0 =
+ b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+ auto sum = param0;  // running total; replaced by each new add in the loop
+ for (int64 i = 1; i < kNumParams; ++i) {
+ auto param =
+ b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"));  // all parameters share the name "p"; only the index differs
+ sum = b.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param));
+ }
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(b.Build());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)  // expects Run to yield true (presumably "module was changed" - confirm)
+ .Run(module.get())
+ .ValueOrDie());
+ SCOPED_TRACE(module->ToString());  // adds the post-pass module dump to any failure reported below
+ for (const HloInstruction* instr : computation->instructions()) {  // the limit is checked on every instruction of the entry computation
+ EXPECT_LE(instr->operand_count(),
+ GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion)
+ << instr->ToString();
+ }
+}
+
} // namespace gpu
} // namespace xla