author    Benjamin Kramer <kramerb@google.com>  2018-06-15 11:10:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-15 11:13:08 -0700
commit    b62d76d932f93ff324d2598cdeac792fa61135a4 (patch)
tree      f6c9dda35d6ae6263cb16984278fbebf93b46574 /tensorflow/compiler/xla/service/hlo_computation.cc
parent    1ca4b6f797a168036e2708faf45753b333f467dc (diff)
[XLA] Switch PostOrder accessors to use std::vector instead of std::list.
std::list is just hilariously inefficient, and the post-order list creation has been rewritten to no longer depend on splicing, so there's no need for the list. While there, remove the old unused post-order list creation code.

PiperOrigin-RevId: 200743677
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 67
1 file changed, 16 insertions(+), 51 deletions(-)
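To make the shape of the change concrete, here is a minimal standalone sketch of the new pattern. It is not XLA code: Node, CollectPostOrder, and MakePostOrder are hypothetical stand-ins for HloInstruction and the helpers touched by this diff, and it uses recursion where the real ComputeInstructionPostOrder keeps an explicit DFS stack. It only illustrates the point of the commit message: once each root's post order is appended into a single caller-owned container instead of being returned and spliced, a std::vector with one up-front reserve() is the natural fit.

#include <cstddef>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> operands;  // stand-in for HloInstruction::operands()
};

// Appends the post order rooted at `root` to *post_order, skipping nodes that
// an earlier root already emitted (the role of `added_instructions` in the diff).
void CollectPostOrder(std::vector<Node*>* post_order, Node* root,
                      std::unordered_set<Node*>* visited) {
  if (!visited->insert(root).second) {
    return;  // Already emitted while walking a previous root.
  }
  for (Node* operand : root->operands) {
    CollectPostOrder(post_order, operand, visited);
  }
  post_order->push_back(root);  // Operands precede their user.
}

// One contiguous vector, sized once; no per-root lists, no splicing.
std::vector<Node*> MakePostOrder(const std::vector<Node*>& roots,
                                 std::size_t node_count) {
  std::vector<Node*> post_order;
  post_order.reserve(node_count);
  std::unordered_set<Node*> visited;
  for (Node* root : roots) {
    CollectPostOrder(&post_order, root, &visited);
  }
  return post_order;
}

In the diff below the same idea appears as the out-parameter ComputeInstructionPostOrder(&post_order, ...) together with post_order.reserve(instruction_count()) in MakeInstructionPostOrder.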
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ef8bb030fb..74173a1685 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -263,46 +263,11 @@ void HloComputation::set_root_instruction(
namespace {
-// Helper class which computes the post order of an expression rooted at a
-// particular instruction.
-class InstructionPostOrderer : public DfsHloVisitorWithDefault {
- public:
- // added_instructions is the set of instructions which have already been
- // accounted for in the post order in previous invocations of
- // GetOrder. Without this mechanism, instructions which are predecessors of
- // multiple root instructions of the computation can be added to the post
- // order more than once.
- static std::list<HloInstruction*> GetOrder(
- HloInstruction* root,
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
- InstructionPostOrderer orderer(added_instructions);
- TF_CHECK_OK(root->Accept(&orderer));
- return std::move(orderer.post_order_);
- }
-
- private:
- explicit InstructionPostOrderer(
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
- : added_instructions_(added_instructions) {}
- ~InstructionPostOrderer() override {}
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- if (added_instructions_->count(hlo_instruction) == 0) {
- post_order_.push_back(hlo_instruction);
- added_instructions_->insert(hlo_instruction);
- }
- return Status::OK();
- }
-
- std::list<HloInstruction*> post_order_;
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
-};
-
// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::list<HloComputation*>* post_order) {
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -314,9 +279,9 @@ void ComputeComputationPostOrder(
}
}
-std::list<HloInstruction*> ComputeInstructionPostOrder(
- HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
- std::list<HloInstruction*> post_order;
+void ComputeInstructionPostOrder(
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
std::vector<std::pair<HloInstruction*, bool>> dfs_stack;
dfs_stack.emplace_back(root, false);
while (!dfs_stack.empty()) {
@@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
if (!visited->insert(current.first).second) {
continue;
}
- post_order.push_back(current.first);
+ post_order->push_back(current.first);
} else {
if (visited->count(current.first)) {
dfs_stack.pop_back();
@@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
}
}
}
- return post_order;
}
} // namespace
-std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
- std::list<HloInstruction*> post_order;
- std::list<HloInstruction*> trace_instructions;
+std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::vector<HloInstruction*> post_order;
+ post_order.reserve(instruction_count());
+ std::vector<HloInstruction*> trace_instructions;
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
@@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(
- post_order.end(),
- ComputeInstructionPostOrder(instruction.get(), &added_instructions));
+ ComputeInstructionPostOrder(&post_order, instruction.get(),
+ &added_instructions);
}
}
- post_order.splice(post_order.end(), trace_instructions);
+ post_order.insert(post_order.end(), trace_instructions.begin(),
+ trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
-std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
@@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
- const std::list<HloInstruction*> all = MakeInstructionPostOrder();
+ const auto& all = MakeInstructionPostOrder();
auto result = MakeUnique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;