aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc67
1 files changed, 16 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ef8bb030fb..74173a1685 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -263,46 +263,11 @@ void HloComputation::set_root_instruction(
namespace {
-// Helper class which computes the post order of an expression rooted at a
-// particular instruction.
-class InstructionPostOrderer : public DfsHloVisitorWithDefault {
- public:
- // added_instructions is the set of instructions which have already been
- // accounted for in the post order in previous invocations of
- // GetOrder. Without this mechanism, instructions which are predecessors of
- // multiple root instructions of the computation can be added to the post
- // order more than once.
- static std::list<HloInstruction*> GetOrder(
- HloInstruction* root,
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
- InstructionPostOrderer orderer(added_instructions);
- TF_CHECK_OK(root->Accept(&orderer));
- return std::move(orderer.post_order_);
- }
-
- private:
- explicit InstructionPostOrderer(
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
- : added_instructions_(added_instructions) {}
- ~InstructionPostOrderer() override {}
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- if (added_instructions_->count(hlo_instruction) == 0) {
- post_order_.push_back(hlo_instruction);
- added_instructions_->insert(hlo_instruction);
- }
- return Status::OK();
- }
-
- std::list<HloInstruction*> post_order_;
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
-};
-
// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::list<HloComputation*>* post_order) {
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -314,9 +279,9 @@ void ComputeComputationPostOrder(
}
}
-std::list<HloInstruction*> ComputeInstructionPostOrder(
- HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
- std::list<HloInstruction*> post_order;
+void ComputeInstructionPostOrder(
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
std::vector<std::pair<HloInstruction*, bool>> dfs_stack;
dfs_stack.emplace_back(root, false);
while (!dfs_stack.empty()) {
@@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
if (!visited->insert(current.first).second) {
continue;
}
- post_order.push_back(current.first);
+ post_order->push_back(current.first);
} else {
if (visited->count(current.first)) {
dfs_stack.pop_back();
@@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
}
}
}
- return post_order;
}
} // namespace
-std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
- std::list<HloInstruction*> post_order;
- std::list<HloInstruction*> trace_instructions;
+std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::vector<HloInstruction*> post_order;
+ post_order.reserve(instruction_count());
+ std::vector<HloInstruction*> trace_instructions;
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
@@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(
- post_order.end(),
- ComputeInstructionPostOrder(instruction.get(), &added_instructions));
+ ComputeInstructionPostOrder(&post_order, instruction.get(),
+ &added_instructions);
}
}
- post_order.splice(post_order.end(), trace_instructions);
+ post_order.insert(post_order.end(), trace_instructions.begin(),
+ trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
-std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
@@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
- const std::list<HloInstruction*> all = MakeInstructionPostOrder();
+ const auto& all = MakeInstructionPostOrder();
auto result = MakeUnique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;