aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar Nick Desaulniers <ndesaulniers@google.com>2018-02-14 13:52:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-14 14:03:01 -0800
commit04348511079ffee7cb169bb3bef42a47ec1736c6 (patch)
tree5fc82d23eaa7d2ccd4c30dfac664b865a3097d6b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parentc5c0ec07321e4911d803ba8aa9a1f4049a88710f (diff)
[XLA] Add reproducer that shows perf issues in HloDataflowAnalysis::UpdateTupleValueSet, then optimize that method.
HloDataflowAnalysis::UpdateTupleValueSet starts to show up in profiles for while bodies that have many GetTupleElement nodes. Use a set to keep track of which HloInstructions we need to propagate DFA for. PiperOrigin-RevId: 185739365
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc28
1 files changed, 18 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index d25fc5d741..ccbbe8f196 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -585,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
+ tensorflow::gtl::FlatSet<HloInstruction*> workset;
+ auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
+ if (workset.insert(instruction).second) {
+ worklist.push(instruction);
+ }
+ };
for (HloComputation* computation : module_->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- worklist.push(instruction);
+ add_to_worklist(instruction);
}
}
while (!worklist.empty()) {
HloInstruction* instruction = worklist.front();
worklist.pop();
+ workset.erase(workset.find(instruction));
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
@@ -608,9 +615,10 @@ void HloDataflowAnalysis::Propagate() {
VLOG(4) << "New value set for " << instruction->name() << ": "
<< GetInstructionValueSet(instruction);
- // Instruction value was updated. Add users to work list.
+ // Instruction value was updated. Add users to work list if we haven't
+ // already.
for (HloInstruction* user : instruction->users()) {
- worklist.push(user);
+ add_to_worklist(user);
// If user sequentially calls a computation, then the respective
// parameter(s) of the computation need to be updated.
@@ -625,10 +633,10 @@ void HloDataflowAnalysis::Propagate() {
// Note that the same instruction can be used in both operand 1 and
// operand 2.
if (user->operand(1) == instruction) {
- worklist.push(user->true_computation()->parameter_instruction(0));
+ add_to_worklist(user->true_computation()->parameter_instruction(0));
}
if (user->operand(2) == instruction) {
- worklist.push(user->false_computation()->parameter_instruction(0));
+ add_to_worklist(user->false_computation()->parameter_instruction(0));
}
} else {
for (HloComputation* called_computation : user->called_computations()) {
@@ -636,7 +644,7 @@ void HloDataflowAnalysis::Propagate() {
call_graph_->GetNode(called_computation);
if (call_graph_node.context() == CallContext::kSequential) {
for (int64 operand_number : user->OperandIndices(instruction)) {
- worklist.push(
+ add_to_worklist(
called_computation->parameter_instruction(operand_number));
}
}
@@ -652,13 +660,13 @@ void HloDataflowAnalysis::Propagate() {
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
if ((callsite.instruction()->opcode() == HloOpcode::kCall) ||
(callsite.instruction()->opcode() == HloOpcode::kConditional)) {
- worklist.push(callsite.instruction());
+ add_to_worklist(callsite.instruction());
} else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
// Add the while itself, and the body and condition parameters.
- worklist.push(callsite.instruction());
- worklist.push(
+ add_to_worklist(callsite.instruction());
+ add_to_worklist(
callsite.instruction()->while_body()->parameter_instruction(0));
- worklist.push(
+ add_to_worklist(
callsite.instruction()->while_condition()->parameter_instruction(
0));
}