about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 98ba162cd9..229eb23f12 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -606,5 +606,35 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) {
op::Parameter()));
}
+// Verifies that after GPU instruction fusion no instruction in the computation
+// has more than GpuInstructionFusion::kMaxOperandsPerFusion operands.
+TEST_F(InstructionFusionTest, AvoidsLargeFusion) {
+  constexpr int64 kNumParams = 200;
+  ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsPerFusion);
+
+  // Build the chain (((p0 + p1) + p2) + ...) over all kNumParams parameters.
+  HloComputation::Builder builder(TestName());
+  const Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 100});
+  auto running_sum = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, operand_shape, "p"));
+  for (int64 param_no = 1; param_no != kNumParams; ++param_no) {
+    auto next_param = builder.AddInstruction(
+        HloInstruction::CreateParameter(param_no, operand_shape, "p"));
+    running_sum = builder.AddInstruction(HloInstruction::CreateBinary(
+        operand_shape, HloOpcode::kAdd, running_sum, next_param));
+  }
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+  // Run the pass (its result must be true) and trace the module on failure.
+  GpuInstructionFusion fusion_pass(/*may_duplicate=*/true);
+  EXPECT_TRUE(fusion_pass.Run(module.get()).ValueOrDie());
+  SCOPED_TRACE(module->ToString());
+  for (const HloInstruction* instruction : computation->instructions()) {
+    EXPECT_LE(instruction->operand_count(),
+              GpuInstructionFusion::kMaxOperandsPerFusion)
+        << instruction->ToString();
+  }
+}
+
} // namespace gpu
} // namespace xla