diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 06:36:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 06:40:39 -0700 |
commit | de1696e9a818646fe6f200db42b150f1b7141900 (patch) | |
tree | d88ba24311c49774dfead232078b45b917368629 | |
parent | 57919740bf151cb6395aa60e30404ee9caa066d6 (diff) |
Fix perfromance of HloComputation::ComputeChannelDependencies
Previously it used an std::map containing std::vector's what added a
large overhead to HloComputation::MakeInstructionPostOrder when a model
contained a large number of channels.
The new implementation replaced it with a FlatMap and an InlineVector
what eliminates a large number of allocations and improves perfromance
by a lot.
PiperOrigin-RevId: 210531816
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 47 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 11 |
2 files changed, 34 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4a59380ed9..c2d0673f49 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -319,12 +319,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( - std::map<int64, std::vector<HloInstruction*>> channel_dependency_map, +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector<HloInstruction*>* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) { + tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const { std::vector<HloInstruction*> dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -362,20 +362,22 @@ void ComputeInstructionPostOrder( // dependencies. switch (current->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = - channel_dependency_map[current->channel_id()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } } break; @@ -386,11 +388,9 @@ void ComputeInstructionPostOrder( } } -} // namespace - -std::map<int64, std::vector<HloInstruction*>> +HloComputation::ChannelDependencyMap HloComputation::ComputeChannelDependencies() const { - std::map<int64, std::vector<HloInstruction*>> channel_dependency_map; + ChannelDependencyMap channel_dependency_map; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { case HloOpcode::kSend: { @@ -421,7 +421,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { std::vector<HloInstruction*> post_order; post_order.reserve(instruction_count()); std::vector<HloInstruction*> trace_instructions; - tensorflow::gtl::FlatMap<HloInstruction*, State> visited; + tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -746,16 +746,19 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = channel_dependency_map[hlo->channel_id()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } } break; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 8d9b694977..59016624f7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -403,8 +403,15 @@ class HloComputation { // instructions. For send&recv pairs it means the send instruction and for // cross-replica-sum the union of the dependencies for all participating // instructions. - std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies() - const; + using ChannelDependencyMap = + tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector<HloInstruction*>* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const; string name_; int64 unique_id_; |