aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-17 16:41:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 16:46:10 -0700
commit0cdf60ff8239a68326af9610e715f42c773be731 (patch)
treecc9fc90379f26fc4aed7255bd2b803fa879967db
parent0b80d098704c72f627f37bfeee0ae19788c06fa8 (diff)
Make HLO liveness analysis correctly handle computations with side effect instructions.
PiperOrigin-RevId: 213361904
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc84
2 files changed, 115 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 3a1dd471c6..5bf055f3c0 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers(
}
}
+// Makes sure that if a live instruction is within a computation used in control
+// flow operations, we mark live even other related instructions.
+void PropagateLivenessThroughControlFlow(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset, CallGraph* call_graph) {
+ const CallGraphNode& call_graph_node =
+ call_graph->GetNode(instruction->parent());
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ HloInstruction* caller = callsite.instruction();
+ if (caller->opcode() == HloOpcode::kWhile) {
+ // If a live instruction is within the %while body or condition
+ // computation, mark the predicate value returned by the condition
+ // computation live as well.
+ MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
+ live_index_map, worklist, workset);
+ } else if (caller->opcode() == HloOpcode::kConditional) {
+ // If a live instruction is within the true or false branches of a
+ // conditional, we mark the predicate operand live as well.
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
+ workset);
+ }
+ }
+ }
+}
+
} // namespace
HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
@@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() {
} else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kWhile &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kWhile) {
PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kParameter) {
PropagateLivenessToParameterCallers(instruction, &live_index_map_,
&worklist, &workset,
call_graph_.get());
@@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() {
MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
}
}
+ PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
+ &worklist, &workset, call_graph_.get());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 01b625c29c..e0ae1173c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
}
+TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ InnerWhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ InnerWhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ OuterWhileCondition {
+ cond_param.2 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
+ constant.5 = s32[] constant(5)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
+ }
+ OuterWhileBody {
+ body_param.2 = (s32[]) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
+ constant.6 = s32[] constant(0)
+ tuple.2 = (s32[]) tuple(constant.6)
+ inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
+ body=InnerWhileBody
+ constant.7 = s32[] constant(1)
+ add.2 = s32[] add(get-tuple-element.8, constant.7)
+ ROOT rtuple = (s32[]) tuple(add.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
+ body=OuterWhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
} // namespace
} // namespace xla