aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-06-15 11:10:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-15 11:13:08 -0700
commitb62d76d932f93ff324d2598cdeac792fa61135a4 (patch)
treef6c9dda35d6ae6263cb16984278fbebf93b46574 /tensorflow
parent1ca4b6f797a168036e2708faf45753b333f467dc (diff)
[XLA] Switch PostOrder accessors to use std::vector instead of std::list.
std::list is just hilariously inefficient and the postorder list creation has been rewritten not to not depend on splicing anymore so there's no need for the list. While there remove the old unused postorder list creation code. PiperOrigin-RevId: 200743677
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.h3
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc4
10 files changed, 29 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 8f1d2f0804..d514b99ed0 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -559,7 +559,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
- std::list<HloComputation*> computations_topological_order =
+ const auto& computations_topological_order =
module->MakeComputationPostOrder();
tensorflow::gtl::FlatSet<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
@@ -742,7 +742,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
- std::list<HloComputation*> computations_topological_order =
+ const auto& computations_topological_order =
module->MakeComputationPostOrder();
// The first step is a forward pass (parameters to root), where we determine
// the potential candidate instructions to use bfloat16 in the outputs that
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ef8bb030fb..74173a1685 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -263,46 +263,11 @@ void HloComputation::set_root_instruction(
namespace {
-// Helper class which computes the post order of an expression rooted at a
-// particular instruction.
-class InstructionPostOrderer : public DfsHloVisitorWithDefault {
- public:
- // added_instructions is the set of instructions which have already been
- // accounted for in the post order in previous invocations of
- // GetOrder. Without this mechanism, instructions which are predecessors of
- // multiple root instructions of the computation can be added to the post
- // order more than once.
- static std::list<HloInstruction*> GetOrder(
- HloInstruction* root,
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
- InstructionPostOrderer orderer(added_instructions);
- TF_CHECK_OK(root->Accept(&orderer));
- return std::move(orderer.post_order_);
- }
-
- private:
- explicit InstructionPostOrderer(
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
- : added_instructions_(added_instructions) {}
- ~InstructionPostOrderer() override {}
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- if (added_instructions_->count(hlo_instruction) == 0) {
- post_order_.push_back(hlo_instruction);
- added_instructions_->insert(hlo_instruction);
- }
- return Status::OK();
- }
-
- std::list<HloInstruction*> post_order_;
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
-};
-
// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::list<HloComputation*>* post_order) {
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -314,9 +279,9 @@ void ComputeComputationPostOrder(
}
}
-std::list<HloInstruction*> ComputeInstructionPostOrder(
- HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
- std::list<HloInstruction*> post_order;
+void ComputeInstructionPostOrder(
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
std::vector<std::pair<HloInstruction*, bool>> dfs_stack;
dfs_stack.emplace_back(root, false);
while (!dfs_stack.empty()) {
@@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
if (!visited->insert(current.first).second) {
continue;
}
- post_order.push_back(current.first);
+ post_order->push_back(current.first);
} else {
if (visited->count(current.first)) {
dfs_stack.pop_back();
@@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
}
}
}
- return post_order;
}
} // namespace
-std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
- std::list<HloInstruction*> post_order;
- std::list<HloInstruction*> trace_instructions;
+std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::vector<HloInstruction*> post_order;
+ post_order.reserve(instruction_count());
+ std::vector<HloInstruction*> trace_instructions;
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
@@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(
- post_order.end(),
- ComputeInstructionPostOrder(instruction.get(), &added_instructions));
+ ComputeInstructionPostOrder(&post_order, instruction.get(),
+ &added_instructions);
}
}
- post_order.splice(post_order.end(), trace_instructions);
+ post_order.insert(post_order.end(), trace_instructions.begin(),
+ trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
-std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
@@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
- const std::list<HloInstruction*> all = MakeInstructionPostOrder();
+ const auto& all = MakeInstructionPostOrder();
auto result = MakeUnique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0da4a305f3..0f111a1a76 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -199,7 +199,7 @@ class HloComputation {
// Compute and return a post-order of the instructions in the computation. In
// this order, definitions of values always appear before their uses.
- std::list<HloInstruction*> MakeInstructionPostOrder() const;
+ std::vector<HloInstruction*> MakeInstructionPostOrder() const;
// Computes and returns the reachability between HLO instructions in the
// computation. The returned HloReachabilityMap is constructed such that
@@ -221,7 +221,7 @@ class HloComputation {
// transitively. The embedded computations are sorted such that if computation
// A calls computation B (eg, via a map instruction) then A will appear after
// B in the list.
- std::list<HloComputation*> MakeEmbeddedComputationsList() const;
+ std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
// Creates a fusion instruction containing the given instructions.
// `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index fcd723af14..8aa26bf520 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -85,8 +85,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
// Remove dead computations.
- std::list<HloComputation*> computations = module->MakeComputationPostOrder();
- for (auto* computation : computations) {
+ for (auto* computation : module->MakeComputationPostOrder()) {
if (live_computations.count(computation) == 0) {
TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
changed = true;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 9c59374b4a..11384c1456 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -451,7 +451,7 @@ int64 HloModule::instruction_count() const {
return n;
}
-std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
+std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// First determine all root computations by building a set of nonroot
// computations (computations which are called by an instruction in the
// module).
@@ -469,7 +469,7 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
// order. This prevents duplication as an embedded computation may be called
// from two different root computations.
std::set<HloComputation*> added_computations;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
for (auto& computation : computations_) {
if (nonroot_computations.count(computation.get()) == 0) {
for (HloComputation* embedded_computation :
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 757e65bda2..5dc94e78e3 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -154,7 +154,7 @@ class HloModule {
// Compute and return a post order of all computations in the module. The sort
// is defined like so: if computation A has an instruction which calls
// computation B, then A will appear after B in the sort.
- std::list<HloComputation*> MakeComputationPostOrder() const;
+ std::vector<HloComputation*> MakeComputationPostOrder() const;
// Gets the computations in this module which aren't for fusion nodes.
//
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 5a0d1e264e..21a9b7291a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- std::list<HloInstruction*> post_order;
+ std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 4738e46f8a..01b088a957 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- const std::list<HloInstruction*>& instructions)
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 69bb2b3cee..48215d32a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -41,7 +41,8 @@ class HloReachabilityMap {
public:
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
- explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions);
+ explicit HloReachabilityMap(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index abedb4063d..d1c4c91b34 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -281,10 +281,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// map from HloInstruction* to the instruction's index in the vector. An
// instruction is "removed" from the vector by setting it's element to
// nullptr.
- std::list<HloInstruction*> post_order_list =
+ std::vector<HloInstruction*> post_order =
computation_->MakeInstructionPostOrder();
- std::vector<HloInstruction*> post_order(post_order_list.begin(),
- post_order_list.end());
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
for (size_t i = 0; i < post_order.size(); ++i) {