diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-06-15 11:10:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-15 11:13:08 -0700 |
commit | b62d76d932f93ff324d2598cdeac792fa61135a4 (patch) | |
tree | f6c9dda35d6ae6263cb16984278fbebf93b46574 /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | 1ca4b6f797a168036e2708faf45753b333f467dc (diff) |
[XLA] Switch PostOrder accessors to use std::vector instead of std::list.
std::list is just hilariously inefficient and the postorder list creation has
been rewritten to no longer depend on splicing, so there's no need for the
list. While there, remove the old unused postorder list creation code.
PiperOrigin-RevId: 200743677
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 67 |
1 files changed, 16 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ef8bb030fb..74173a1685 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -263,46 +263,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list<HloInstruction*> GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list<HloInstruction*> post_order_; - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. 
void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet<HloComputation*>* visited, - std::list<HloComputation*>* post_order) { + std::vector<HloComputation*>* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -314,9 +279,9 @@ void ComputeComputationPostOrder( } } -std::list<HloInstruction*> ComputeInstructionPostOrder( - HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) { - std::list<HloInstruction*> post_order; +void ComputeInstructionPostOrder( + std::vector<HloInstruction*>* post_order, HloInstruction* root, + tensorflow::gtl::FlatSet<HloInstruction*>* visited) { std::vector<std::pair<HloInstruction*, bool>> dfs_stack; dfs_stack.emplace_back(root, false); while (!dfs_stack.empty()) { @@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( if (!visited->insert(current.first).second) { continue; } - post_order.push_back(current.first); + post_order->push_back(current.first); } else { if (visited->count(current.first)) { dfs_stack.pop_back(); @@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( } } } - return post_order; } } // namespace -std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { - std::list<HloInstruction*> post_order; - std::list<HloInstruction*> trace_instructions; +std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { + std::vector<HloInstruction*> post_order; + post_order.reserve(instruction_count()); + std::vector<HloInstruction*> trace_instructions; tensorflow::gtl::FlatSet<HloInstruction*> added_instructions; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { @@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { // users). 
trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice( - post_order.end(), - ComputeInstructionPostOrder(instruction.get(), &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), + &added_instructions); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList() +std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet<HloComputation*> visited; - std::list<HloComputation*> post_order; + std::vector<HloComputation*> post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() const { - const std::list<HloInstruction*> all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique<HloReachabilityMap>(all); std::vector<HloInstruction*> inputs; |