aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction_test.cc
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-06-25 17:27:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 17:30:25 -0700
commit89013c6f76568736cd6d8395f73db53045303412 (patch)
treeaa16b75243df34b120eb0799d8c7534b494fcfad /tensorflow/compiler/xla/service/hlo_instruction_test.cc
parentee8703f342269dca881c17c6db3177355fcd18c7 (diff)
[XLA] Avoid fusion nodes to have duplicate operands during replacing uses.
PiperOrigin-RevId: 202049336
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 120162a956..3847d68efa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1137,6 +1137,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
}
+TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
+ // Fused expression:
+ //
+ // x y
+ // | |
+ // | transpose
+ // \ /
+ // dot
+ const Shape s = ShapeUtil::MakeShape(F32, {10, 10});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok());
+
+ EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
+ EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
+}
+
TEST_F(HloInstructionTest, FusionEquality) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());