From de1696e9a818646fe6f200db42b150f1b7141900 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Aug 2018 06:36:02 -0700 Subject: Fix perfromance of HloComputation::ComputeChannelDependencies Previously it used an std::map containing std::vector's what added a large overhead to HloComputation::MakeInstructionPostOrder when a model contained a large number of channels. The new implementation replaced it with a FlatMap and an InlineVector what eliminates a large number of allocations and improves perfromance by a lot. PiperOrigin-RevId: 210531816 --- tensorflow/compiler/xla/service/hlo_computation.cc | 47 ++++++++++++---------- 1 file changed, 25 insertions(+), 22 deletions(-) (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc') diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 4a59380ed9..c2d0673f49 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -319,12 +319,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( - std::map> channel_dependency_map, +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) { + tensorflow::gtl::FlatMap* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -362,20 +362,22 @@ void ComputeInstructionPostOrder( // dependencies. switch (current->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = - channel_dependency_map[current->channel_id()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } } break; @@ -386,11 +388,9 @@ void ComputeInstructionPostOrder( } } -} // namespace - -std::map> +HloComputation::ChannelDependencyMap HloComputation::ComputeChannelDependencies() const { - std::map> channel_dependency_map; + ChannelDependencyMap channel_dependency_map; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { case HloOpcode::kSend: { @@ -421,7 +421,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -746,16 +746,19 @@ std::unique_ptr HloComputation::ComputeReachability() switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = channel_dependency_map[hlo->channel_id()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } } break; } -- cgit v1.2.3