diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index f795a6ef62..ea5749581b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1077,6 +1077,48 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } +TEST_F(HloInstructionTest, IsRandomFusable) { + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0<float>(0.0))); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0<float>(1.0))); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + shape, RandomDistribution::RNG_NORMAL, {const0, const1})); + + auto* computation = hlo_module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({rng, const0, const1}, + HloInstruction::FusionKind::kLoop); + + auto* root = computation->root_instruction(); + + EXPECT_EQ(HloOpcode::kFusion, root->opcode()); + } + { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0<float>(0.0))); + auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0<float>(1.0))); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + shape, RandomDistribution::RNG_NORMAL, {const0, const1})); + builder.AddInstruction(HloInstruction::CreateUnary( + shape, HloOpcode::kNegate, rng)); + auto* computation = hlo_module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({rng, const0, const1}, + HloInstruction::FusionKind::kLoop); + + auto* root = computation->root_instruction(); + + EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); + } +} + + TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. That |