aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation.cc 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc 67
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dce.cc 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_group_util.cc 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_reachability.cc 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_reachability.h 3
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.cc 4
10 files changed, 29 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 8f1d2f0804..d514b99ed0 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -559,7 +559,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
- std::list<HloComputation*> computations_topological_order =
+ const auto& computations_topological_order =
module->MakeComputationPostOrder();
tensorflow::gtl::FlatSet<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
@@ -742,7 +742,7 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
- std::list<HloComputation*> computations_topological_order =
+ const auto& computations_topological_order =
module->MakeComputationPostOrder();
// The first step is a forward pass (parameters to root), where we determine
// the potential candidate instructions to use bfloat16 in the outputs that
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index ef8bb030fb..74173a1685 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -263,46 +263,11 @@ void HloComputation::set_root_instruction(
namespace {
-// Helper class which computes the post order of an expression rooted at a
-// particular instruction.
-class InstructionPostOrderer : public DfsHloVisitorWithDefault {
- public:
- // added_instructions is the set of instructions which have already been
- // accounted for in the post order in previous invocations of
- // GetOrder. Without this mechanism, instructions which are predecessors of
- // multiple root instructions of the computation can be added to the post
- // order more than once.
- static std::list<HloInstruction*> GetOrder(
- HloInstruction* root,
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
- InstructionPostOrderer orderer(added_instructions);
- TF_CHECK_OK(root->Accept(&orderer));
- return std::move(orderer.post_order_);
- }
-
- private:
- explicit InstructionPostOrderer(
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
- : added_instructions_(added_instructions) {}
- ~InstructionPostOrderer() override {}
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- if (added_instructions_->count(hlo_instruction) == 0) {
- post_order_.push_back(hlo_instruction);
- added_instructions_->insert(hlo_instruction);
- }
- return Status::OK();
- }
-
- std::list<HloInstruction*> post_order_;
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
-};
-
// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::list<HloComputation*>* post_order) {
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -314,9 +279,9 @@ void ComputeComputationPostOrder(
}
}
-std::list<HloInstruction*> ComputeInstructionPostOrder(
- HloInstruction* root, tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
- std::list<HloInstruction*> post_order;
+void ComputeInstructionPostOrder(
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatSet<HloInstruction*>* visited) {
std::vector<std::pair<HloInstruction*, bool>> dfs_stack;
dfs_stack.emplace_back(root, false);
while (!dfs_stack.empty()) {
@@ -326,7 +291,7 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
if (!visited->insert(current.first).second) {
continue;
}
- post_order.push_back(current.first);
+ post_order->push_back(current.first);
} else {
if (visited->count(current.first)) {
dfs_stack.pop_back();
@@ -347,14 +312,14 @@ std::list<HloInstruction*> ComputeInstructionPostOrder(
}
}
}
- return post_order;
}
} // namespace
-std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
- std::list<HloInstruction*> post_order;
- std::list<HloInstruction*> trace_instructions;
+std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::vector<HloInstruction*> post_order;
+ post_order.reserve(instruction_count());
+ std::vector<HloInstruction*> trace_instructions;
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
@@ -363,21 +328,21 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(
- post_order.end(),
- ComputeInstructionPostOrder(instruction.get(), &added_instructions));
+ ComputeInstructionPostOrder(&post_order, instruction.get(),
+ &added_instructions);
}
}
- post_order.splice(post_order.end(), trace_instructions);
+ post_order.insert(post_order.end(), trace_instructions.begin(),
+ trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
-std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
@@ -648,7 +613,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
- const std::list<HloInstruction*> all = MakeInstructionPostOrder();
+ const auto& all = MakeInstructionPostOrder();
auto result = MakeUnique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0da4a305f3..0f111a1a76 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -199,7 +199,7 @@ class HloComputation {
// Compute and return a post-order of the instructions in the computation. In
// this order, definitions of values always appear before their uses.
- std::list<HloInstruction*> MakeInstructionPostOrder() const;
+ std::vector<HloInstruction*> MakeInstructionPostOrder() const;
// Computes and returns the reachability between HLO instructions in the
// computation. The returned HloReachabilityMap is constructed such that
@@ -221,7 +221,7 @@ class HloComputation {
// transitively. The embedded computations are sorted such that if computation
// A calls computation B (e.g., via a map instruction) then A will appear after
// B in the list.
- std::list<HloComputation*> MakeEmbeddedComputationsList() const;
+ std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
// Creates a fusion instruction containing the given instructions.
// `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index fcd723af14..8aa26bf520 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -85,8 +85,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
// Remove dead computations.
- std::list<HloComputation*> computations = module->MakeComputationPostOrder();
- for (auto* computation : computations) {
+ for (auto* computation : module->MakeComputationPostOrder()) {
if (live_computations.count(computation) == 0) {
TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
changed = true;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 9c59374b4a..11384c1456 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -451,7 +451,7 @@ int64 HloModule::instruction_count() const {
return n;
}
-std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
+std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// First determine all root computations by building a set of nonroot
// computations (computations which are called by an instruction in the
// module).
@@ -469,7 +469,7 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
// order. This prevents duplication as an embedded computation may be called
// from two different root computations.
std::set<HloComputation*> added_computations;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
for (auto& computation : computations_) {
if (nonroot_computations.count(computation.get()) == 0) {
for (HloComputation* embedded_computation :
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 757e65bda2..5dc94e78e3 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -154,7 +154,7 @@ class HloModule {
// Compute and return a post order of all computations in the module. The sort
// is defined like so: if computation A has an instruction which calls
// computation B, then A will appear after B in the sort.
- std::list<HloComputation*> MakeComputationPostOrder() const;
+ std::vector<HloComputation*> MakeComputationPostOrder() const;
// Gets the computations in this module which aren't for fusion nodes.
//
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 5a0d1e264e..21a9b7291a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- std::list<HloInstruction*> post_order;
+ std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 4738e46f8a..01b088a957 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- const std::list<HloInstruction*>& instructions)
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 69bb2b3cee..48215d32a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -41,7 +41,8 @@ class HloReachabilityMap {
public:
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
- explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions);
+ explicit HloReachabilityMap(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index abedb4063d..d1c4c91b34 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -281,10 +281,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// map from HloInstruction* to the instruction's index in the vector. An
// instruction is "removed" from the vector by setting its element to
// nullptr.
- std::list<HloInstruction*> post_order_list =
+ std::vector<HloInstruction*> post_order =
computation_->MakeInstructionPostOrder();
- std::vector<HloInstruction*> post_order(post_order_list.begin(),
- post_order_list.end());
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
for (size_t i = 0; i < post_order.size(); ++i) {