From 0cdf60ff8239a68326af9610e715f42c773be731 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Sep 2018 16:41:38 -0700 Subject: Make HLO liveness analysis correctly handle computations with side effect instructions. PiperOrigin-RevId: 213361904 --- .../compiler/xla/service/hlo_liveness_analysis.cc | 35 +++++++-- .../xla/service/hlo_liveness_analysis_test.cc | 84 ++++++++++++++++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 3a1dd471c6..5bf055f3c0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers( } } +// Makes sure that if a live instruction is within a computation used in control +// flow operations, we mark live even other related instructions. +void PropagateLivenessThroughControlFlow( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + HloInstruction* caller = callsite.instruction(); + if (caller->opcode() == HloOpcode::kWhile) { + // If a live instruction is within the %while body or condition + // computation, mark the predicate value returned by the condition + // computation live as well. + MarkLiveAtIndex(caller->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); + } else if (caller->opcode() == HloOpcode::kConditional) { + // If a live instruction is within the true or false branches of a + // conditional, we mark the predicate operand live as well. + MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist, + workset); + } + } + } +} + } // namespace HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) @@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() { } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kWhile && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kWhile) { PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kParameter) { PropagateLivenessToParameterCallers(instruction, &live_index_map_, &worklist, &workset, call_graph_.get()); @@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() { MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); } } + PropagateLivenessThroughControlFlow(instruction, &live_index_map_, + &worklist, &workset, call_graph_.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 01b625c29c..e0ae1173c6 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); } +TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + InnerWhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + InnerWhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + OuterWhileCondition { + cond_param.2 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 + constant.5 = s32[] constant(5) + ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + } + OuterWhileBody { + body_param.2 = (s32[]) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0 + constant.6 = s32[] constant(0) + tuple.2 = (s32[]) tuple(constant.6) + inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition, + body=InnerWhileBody + constant.7 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.8, constant.7) + ROOT rtuple = (s32[]) tuple(add.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=OuterWhileCondition, + body=OuterWhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + } // namespace } // namespace xla -- cgit v1.2.3