aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index f795a6ef62..ea5749581b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1077,6 +1077,48 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
root2->operand(1)->operand(0)->shape()));
}
+TEST_F(HloInstructionTest, IsRandomFusable) {
+ auto shape = ShapeUtil::MakeShape(F32, {2, 2});
+ {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = CreateNewModule();
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(0.0)));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(1.0)));
+ auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+ shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
+
+ auto* computation = hlo_module->AddEntryComputation(builder.Build());
+ computation->CreateFusionInstruction({rng, const0, const1},
+ HloInstruction::FusionKind::kLoop);
+
+ auto* root = computation->root_instruction();
+
+ EXPECT_EQ(HloOpcode::kFusion, root->opcode());
+ }
+ {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = CreateNewModule();
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(0.0)));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(1.0)));
+ auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+ shape, RandomDistribution::RNG_NORMAL, {const0, const1}));
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ shape, HloOpcode::kNegate, rng));
+ auto* computation = hlo_module->AddEntryComputation(builder.Build());
+ computation->CreateFusionInstruction({rng, const0, const1},
+ HloInstruction::FusionKind::kLoop);
+
+ auto* root = computation->root_instruction();
+
+ EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode());
+ }
+}
+
+
TEST_F(HloInstructionTest, CloneSuffixNames) {
// Test that the suffix string added to cloned instructions is not
// duplicated. Rather a numeric incrementing value should be appended. That