diff options
author | 2018-06-25 17:27:30 -0700 | |
---|---|---|
committer | 2018-06-25 17:30:25 -0700 | |
commit | 89013c6f76568736cd6d8395f73db53045303412 (patch) | |
tree | aa16b75243df34b120eb0799d8c7534b494fcfad /tensorflow/compiler/xla/service/hlo_instruction_test.cc | |
parent | ee8703f342269dca881c17c6db3177355fcd18c7 (diff) |
[XLA] Avoid fusion nodes to have duplicate operands during replacing uses.
PiperOrigin-RevId: 202049336
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 120162a956..3847d68efa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1137,6 +1137,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } +TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + const Shape s = ShapeUtil::MakeShape(F32, {10, 10}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok()); + + EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y)); + EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1); +} + TEST_F(HloInstructionTest, FusionEquality) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); |