diff options
author | Nick Desaulniers <ndesaulniers@google.com> | 2018-02-14 13:52:18 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-14 14:03:01 -0800 |
commit | 04348511079ffee7cb169bb3bef42a47ec1736c6 (patch) | |
tree | 5fc82d23eaa7d2ccd4c30dfac664b865a3097d6b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | |
parent | c5c0ec07321e4911d803ba8aa9a1f4049a88710f (diff) |
[XLA] Add reproducer that shows perf issues in HloDataflowAnalysis::UpdateTupleValueSet, then optimize that method.
HloDataflowAnalysis::UpdateTupleValueSet starts to show up in profiles for while bodies that have many GetTupleElement nodes.
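Use a set to keep track of which HloInstructions are already queued on the dataflow-analysis (DFA) propagation worklist, so that an instruction is not enqueued again while it is still pending.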
PiperOrigin-RevId: 185739365
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 28 |
1 files changed, 18 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d25fc5d741..ccbbe8f196 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -585,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue<HloInstruction*> worklist; + tensorflow::gtl::FlatSet<HloInstruction*> workset; + auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + if (workset.insert(instruction).second) { + worklist.push(instruction); + } + }; for (HloComputation* computation : module_->computations()) { for (HloInstruction* instruction : computation->instructions()) { - worklist.push(instruction); + add_to_worklist(instruction); } } while (!worklist.empty()) { HloInstruction* instruction = worklist.front(); worklist.pop(); + workset.erase(workset.find(instruction)); VLOG(3) << "Worklist top: " << instruction->name(); VLOG(3) << ToString(); @@ -608,9 +615,10 @@ void HloDataflowAnalysis::Propagate() { VLOG(4) << "New value set for " << instruction->name() << ": " << GetInstructionValueSet(instruction); - // Instruction value was updated. Add users to work list. + // Instruction value was updated. Add users to work list if we haven't + // already. for (HloInstruction* user : instruction->users()) { - worklist.push(user); + add_to_worklist(user); // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. @@ -625,10 +633,10 @@ void HloDataflowAnalysis::Propagate() { // Note that the same instruction can be used in both operand 1 and // operand 2. 
if (user->operand(1) == instruction) { - worklist.push(user->true_computation()->parameter_instruction(0)); + add_to_worklist(user->true_computation()->parameter_instruction(0)); } if (user->operand(2) == instruction) { - worklist.push(user->false_computation()->parameter_instruction(0)); + add_to_worklist(user->false_computation()->parameter_instruction(0)); } } else { for (HloComputation* called_computation : user->called_computations()) { @@ -636,7 +644,7 @@ void HloDataflowAnalysis::Propagate() { call_graph_->GetNode(called_computation); if (call_graph_node.context() == CallContext::kSequential) { for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( + add_to_worklist( called_computation->parameter_instruction(operand_number)); } } @@ -652,13 +660,13 @@ void HloDataflowAnalysis::Propagate() { for (const CallSite& callsite : call_graph_node.caller_callsites()) { if ((callsite.instruction()->opcode() == HloOpcode::kCall) || (callsite.instruction()->opcode() == HloOpcode::kConditional)) { - worklist.push(callsite.instruction()); + add_to_worklist(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. - worklist.push(callsite.instruction()); - worklist.push( + add_to_worklist(callsite.instruction()); + add_to_worklist( callsite.instruction()->while_body()->parameter_instruction(0)); - worklist.push( + add_to_worklist( callsite.instruction()->while_condition()->parameter_instruction( 0)); } |