diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-15 14:27:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-15 14:31:07 -0700 |
commit | 002b488dcd9c1eaac4f4aba2ac1301c32c6beb06 (patch) | |
tree | 2c3bc504a99d0bc112a55875d678c6dbbe197c3c /tensorflow/compiler/xla/service/hlo_ordering_test.cc | |
parent | 856438f65e5705b373413bce29758a92194ff9b6 (diff) |
Fix the HLO alias analysis and copy insertion to cope with the new kConditional instruction.
PiperOrigin-RevId: 189245979
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering_test.cc | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index a989fce632..441d790f0e 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -362,5 +362,66 @@ ENTRY while.v11 { ordering.ToString(); // Shouldn't crash. } +TEST_F(HloOrderingTest, ConditionalInstructionOrdering) { + const char* module_str = R"( +HloModule test_conditional_module + +true_branch { + param.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1) +} + +false_branch { + param.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1 + add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4) + ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4) +} + +ENTRY root { + param.3 = (pred[], (s32[], s32[])) parameter(0) + pred.1 = pred[] get-tuple-element(param.3), index=0 + cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1 + conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch + cond_res.1 = s32[] get-tuple-element(conditional), index=0 + cond_res.2 = s32[] get-tuple-element(conditional), index=1 + add.3 = s32[] add(cond_res.1, cond_res.2) + ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Even though the true and false branches has no ordering, since they do not + // interfere (as they are mutually exclusive), we define the true computation + // to be before the false one. + // Similarly, any instruction in the true or false branches are considered + // before the conditional instruction. The roots are effectively "at the same + // time" WRT the conditional, but they are Phi-ed anyway. + HloInstruction* add_1 = FindInstruction(module.get(), "add.1"); + HloInstruction* add_2 = FindInstruction(module.get(), "add.2"); + HloInstruction* add_3 = FindInstruction(module.get(), "add.3"); + HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_2))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_3))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(add_3))); +} + } // namespace } // namespace xla |