diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/fusion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/fusion_test.cc | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 519de64d1b..f9676dfc19 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -521,6 +521,42 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { *ExecuteAndTransfer(std::move(hlo_module), {})); } +// When a constant (or other op) which has multiple users is imported +// into a fusion, it should remain shared, rather than being duplicated +// within the fusion. +XLA_TEST_F(FusionTest, SharedConstant) { + auto hlo_module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<int32>({0}))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); + auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); + auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); + auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction( + {add4, add3, add2, add1, const1}, + HloInstruction::FusionKind::kLoop); + + HloComputation* entry_comp = hlo_module->entry_computation(); + + // entry computation contains the constant(0) and the fusion + EXPECT_EQ(entry_comp->instructions().size(), 2); + + // fused instruction contains the constant(2), the parameter, and 4 adds + EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6); + + LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}), + *ExecuteAndTransfer(std::move(hlo_module), {})); +} + XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); } XLA_TEST_F(FusionTest, Subtract2D) { |