diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 67 |
1 files changed, 16 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ef8bb030fb..74173a1685 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -263,46 +263,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list<HloInstruction*> GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list<HloInstruction*> post_order_; - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet<HloComputation*>* visited, - std::list<HloComputation*>* post_order) { + std::vector<HloComputation*>* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -314,9 +279,9 @@ void ComputeComputationPostOrder( } } -std::list<HloInstruction*> ComputeInstructionPostOrder( - HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) { - std::list<HloInstruction*> post_order; +void ComputeInstructionPostOrder( + std::vector<HloInstruction*>* post_order, HloInstruction* root, + tensorflow::gtl::FlatSet<HloInstruction*>* visited) { std::vector<std::pair<HloInstruction*, bool>> dfs_stack; dfs_stack.emplace_back(root, false); while (!dfs_stack.empty()) { @@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( if (!visited->insert(current.first).second) { continue; } - post_order.push_back(current.first); + post_order->push_back(current.first); } else { if (visited->count(current.first)) { dfs_stack.pop_back(); @@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( } } } - return post_order; } } // namespace -std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { - std::list<HloInstruction*> post_order; - std::list<HloInstruction*> trace_instructions; +std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { + std::vector<HloInstruction*> post_order; + post_order.reserve(instruction_count()); + std::vector<HloInstruction*> trace_instructions; tensorflow::gtl::FlatSet<HloInstruction*> added_instructions; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { @@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice( - post_order.end(), - ComputeInstructionPostOrder(instruction.get(), &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), + &added_instructions); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList() +std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet<HloComputation*> visited; - std::list<HloComputation*> post_order; + std::vector<HloComputation*> post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() const { - const std::list<HloInstruction*> all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique<HloReachabilityMap>(all); std::vector<HloInstruction*> inputs; |