diff options
10 files changed, 29 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 8f1d2f0804..d514b99ed0 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -559,7 +559,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { - std::list<HloComputation*> computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); tensorflow::gtl::FlatSet<const HloComputation*> resolved; for (auto comp_it = computations_topological_order.rbegin(); @@ -742,7 +742,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - std::list<HloComputation*> computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ef8bb030fb..74173a1685 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -263,46 +263,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list<HloInstruction*> GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list<HloInstruction*> post_order_; - tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet<HloComputation*>* visited, - std::list<HloComputation*>* post_order) { + std::vector<HloComputation*>* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -314,9 +279,9 @@ void ComputeComputationPostOrder( } } -std::list<HloInstruction*> ComputeInstructionPostOrder( - HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) { - std::list<HloInstruction*> post_order; +void ComputeInstructionPostOrder( + std::vector<HloInstruction*>* post_order, HloInstruction* root, + tensorflow::gtl::FlatSet<HloInstruction*>* visited) { std::vector<std::pair<HloInstruction*, bool>> dfs_stack; dfs_stack.emplace_back(root, false); while (!dfs_stack.empty()) { @@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( if (!visited->insert(current.first).second) { continue; } - post_order.push_back(current.first); + post_order->push_back(current.first); } else { if (visited->count(current.first)) { dfs_stack.pop_back(); @@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder( } } } - return post_order; } } // namespace -std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { - std::list<HloInstruction*> post_order; - std::list<HloInstruction*> trace_instructions; +std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { + std::vector<HloInstruction*> post_order; + post_order.reserve(instruction_count()); + std::vector<HloInstruction*> trace_instructions; tensorflow::gtl::FlatSet<HloInstruction*> added_instructions; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { @@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice( - post_order.end(), - ComputeInstructionPostOrder(instruction.get(), &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), + &added_instructions); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList() +std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet<HloComputation*> visited; - std::list<HloComputation*> post_order; + std::vector<HloComputation*> post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() const { - const std::list<HloInstruction*> all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique<HloReachabilityMap>(all); std::vector<HloInstruction*> inputs; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0da4a305f3..0f111a1a76 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -199,7 +199,7 @@ class HloComputation { // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. - std::list<HloInstruction*> MakeInstructionPostOrder() const; + std::vector<HloInstruction*> MakeInstructionPostOrder() const; // Computes and returns the reachability between HLO instructions in the // computation. The returned HloReachabilityMap is constructed such that @@ -221,7 +221,7 @@ class HloComputation { // transitively. The embedded computations are sorted such that if computation // A calls computation B (eg, via a map instruction) then A will appear after // B in the list. - std::list<HloComputation*> MakeEmbeddedComputationsList() const; + std::vector<HloComputation*> MakeEmbeddedComputationsList() const; // Creates a fusion instruction containing the given instructions. // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fcd723af14..8aa26bf520 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -85,8 +85,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { } // Remove dead computations. - std::list<HloComputation*> computations = module->MakeComputationPostOrder(); - for (auto* computation : computations) { + for (auto* computation : module->MakeComputationPostOrder()) { if (live_computations.count(computation) == 0) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 9c59374b4a..11384c1456 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -451,7 +451,7 @@ int64 HloModule::instruction_count() const { return n; } -std::list<HloComputation*> HloModule::MakeComputationPostOrder() const { +std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). @@ -469,7 +469,7 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const { // order. This prevents duplication as an embedded computation may be called // from two different root computations. std::set<HloComputation*> added_computations; - std::list<HloComputation*> post_order; + std::vector<HloComputation*> post_order; for (auto& computation : computations_) { if (nonroot_computations.count(computation.get()) == 0) { for (HloComputation* embedded_computation : diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 757e65bda2..5dc94e78e3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -154,7 +154,7 @@ class HloModule { // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. - std::list<HloComputation*> MakeComputationPostOrder() const; + std::vector<HloComputation*> MakeComputationPostOrder() const; // Gets the computations in this module which aren't for fusion nodes. // diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 5a0d1e264e..21a9b7291a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr<std::unique_ptr<HloReachabilityMap>> HloModuleGroupUtil::ComputeReachability( tensorflow::gtl::ArraySlice<HloComputation*> computations) { - std::list<HloInstruction*> post_order; + std::vector<HloInstruction*> post_order; auto visit_function = [&](HloInstruction* instruction, const std::vector<HloInstruction*>& instruction_group) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4738e46f8a..01b088a957 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - const std::list<HloInstruction*>& instructions) + tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 69bb2b3cee..48215d32a8 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -41,7 +41,8 @@ class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given // instructions. - explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions); + explicit HloReachabilityMap( + tensorflow::gtl::ArraySlice<const HloInstruction*> instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index abedb4063d..d1c4c91b34 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -281,10 +281,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { // map from HloInstruction* to the instruction's index in the vector. An // instruction is "removed" from the vector by setting it's element to // nullptr. - std::list<HloInstruction*> post_order_list = + std::vector<HloInstruction*> post_order = computation_->MakeInstructionPostOrder(); - std::vector<HloInstruction*> post_order(post_order_list.begin(), - post_order_list.end()); tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { |