aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/fusion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/fusion_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc36
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 519de64d1b..f9676dfc19 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -521,6 +521,42 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
*ExecuteAndTransfer(std::move(hlo_module), {}));
}
+// When a constant (or other op) which has multiple users is imported
+// into a fusion, it should remain shared, rather than being duplicated
+// within the fusion.
+XLA_TEST_F(FusionTest, SharedConstant) {
+ auto hlo_module = CreateNewModule();
+
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
+ auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
+ auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(
+ {add4, add3, add2, add1, const1},
+ HloInstruction::FusionKind::kLoop);
+
+ HloComputation* entry_comp = hlo_module->entry_computation();
+
+ // entry computation contains the constant(0) and the fusion
+ EXPECT_EQ(entry_comp->instructions().size(), 2);
+
+ // fused instruction contains the constant(2), the parameter, and 4 adds
+ EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6);
+
+ LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
XLA_TEST_F(FusionTest, Subtract2D) {