aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 06:36:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 06:40:39 -0700
commitde1696e9a818646fe6f200db42b150f1b7141900 (patch)
treed88ba24311c49774dfead232078b45b917368629
parent57919740bf151cb6395aa60e30404ee9caa066d6 (diff)
Fix perfromance of HloComputation::ComputeChannelDependencies
Previously it used an std::map containing std::vector's what added a large overhead to HloComputation::MakeInstructionPostOrder when a model contained a large number of channels. The new implementation replaced it with a FlatMap and an InlineVector what eliminates a large number of allocations and improves perfromance by a lot. PiperOrigin-RevId: 210531816
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc47
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h11
2 files changed, 34 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 4a59380ed9..c2d0673f49 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -319,12 +319,12 @@ void ComputeComputationPostOrder(
}
}
-enum State { kVisiting, kVisited };
+} // namespace
-void ComputeInstructionPostOrder(
- std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
+void HloComputation::ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
std::vector<HloInstruction*> dfs_stack;
dfs_stack.push_back(root);
while (!dfs_stack.empty()) {
@@ -362,20 +362,22 @@ void ComputeInstructionPostOrder(
// dependencies.
switch (current->opcode()) {
case HloOpcode::kRecvDone: {
- const auto& dependencies =
- channel_dependency_map[current->channel_id()];
- for (HloInstruction* op : dependencies) {
- dfs_stack.emplace_back(op);
+ auto it = channel_dependency_map.find(current->channel_id());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
}
break;
}
case HloOpcode::kCrossReplicaSum: {
auto all_reduce_id = current->all_reduce_id();
if (all_reduce_id) {
- const auto& dependencies =
- channel_dependency_map[all_reduce_id.value()];
- for (HloInstruction* op : dependencies) {
- dfs_stack.emplace_back(op);
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ for (HloInstruction* op : it->second) {
+ dfs_stack.emplace_back(op);
+ }
}
}
break;
@@ -386,11 +388,9 @@ void ComputeInstructionPostOrder(
}
}
-} // namespace
-
-std::map<int64, std::vector<HloInstruction*>>
+HloComputation::ChannelDependencyMap
HloComputation::ComputeChannelDependencies() const {
- std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
+ ChannelDependencyMap channel_dependency_map;
for (const auto& instruction : instructions_) {
switch (instruction->opcode()) {
case HloOpcode::kSend: {
@@ -421,7 +421,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatMap<HloInstruction*, State> visited;
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -746,16 +746,19 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
switch (hlo->opcode()) {
case HloOpcode::kRecvDone: {
- const auto& dependencies = channel_dependency_map[hlo->channel_id()];
- absl::c_copy(dependencies, std::back_inserter(inputs));
+ auto it = channel_dependency_map.find(hlo->channel_id());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
break;
}
case HloOpcode::kCrossReplicaSum: {
auto all_reduce_id = hlo->all_reduce_id();
if (all_reduce_id) {
- const auto& dependencies =
- channel_dependency_map[all_reduce_id.value()];
- absl::c_copy(dependencies, std::back_inserter(inputs));
+ auto it = channel_dependency_map.find(all_reduce_id.value());
+ if (it != channel_dependency_map.end()) {
+ absl::c_copy(it->second, std::back_inserter(inputs));
+ }
}
break;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 8d9b694977..59016624f7 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -403,8 +403,15 @@ class HloComputation {
// instructions. For send&recv pairs it means the send instruction and for
// cross-replica-sum the union of the dependencies for all participating
// instructions.
- std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
- const;
+ using ChannelDependencyMap =
+ tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ ChannelDependencyMap ComputeChannelDependencies() const;
+
+ enum VisitState { kVisiting, kVisited };
+ void ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
string name_;
int64 unique_id_;