aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-15 14:27:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 14:31:07 -0700
commit002b488dcd9c1eaac4f4aba2ac1301c32c6beb06 (patch)
tree2c3bc504a99d0bc112a55875d678c6dbbe197c3c /tensorflow/compiler/xla/service/hlo_ordering_test.cc
parent856438f65e5705b373413bce29758a92194ff9b6 (diff)
Fix the HLO alias analysis and copy insertion to cope with the new kConditional instruction.
PiperOrigin-RevId: 189245979
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index a989fce632..441d790f0e 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -362,5 +362,66 @@ ENTRY while.v11 {
ordering.ToString(); // Shouldn't crash.
}
+TEST_F(HloOrderingTest, ConditionalInstructionOrdering) {
+ const char* module_str = R"(
+HloModule test_conditional_module
+
+true_branch {
+ param.1 = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1
+ add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2)
+ ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1)
+}
+
+false_branch {
+ param.2 = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0
+ get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1
+ add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4)
+ ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4)
+}
+
+ENTRY root {
+ param.3 = (pred[], (s32[], s32[])) parameter(0)
+ pred.1 = pred[] get-tuple-element(param.3), index=0
+ cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1
+ conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch
+ cond_res.1 = s32[] get-tuple-element(conditional), index=0
+ cond_res.2 = s32[] get-tuple-element(conditional), index=1
+ add.3 = s32[] add(cond_res.1, cond_res.2)
+ ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+ DependencyHloOrdering ordering(module.get());
+
+ // Even though the true and false branches has no ordering, since they do not
+ // interfere (as they are mutually exclusive), we define the true computation
+ // to be before the false one.
+ // Similarly, any instruction in the true or false branches are considered
+ // before the conditional instruction. The roots are effectively "at the same
+ // time" WRT the conditional, but they are Phi-ed anyway.
+ HloInstruction* add_1 = FindInstruction(module.get(), "add.1");
+ HloInstruction* add_2 = FindInstruction(module.get(), "add.2");
+ HloInstruction* add_3 = FindInstruction(module.get(), "add.3");
+ HloInstruction* conditional = FindInstruction(module.get(), "conditional");
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+ dataflow->GetValueDefinedAt(add_2)));
+ EXPECT_TRUE(
+ ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+ dataflow->GetValueDefinedAt(conditional)));
+ EXPECT_TRUE(
+ ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+ dataflow->GetValueDefinedAt(conditional)));
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1),
+ dataflow->GetValueDefinedAt(add_3)));
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2),
+ dataflow->GetValueDefinedAt(add_3)));
+}
+
} // namespace
} // namespace xla