aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 06:36:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 06:40:39 -0700
commitde1696e9a818646fe6f200db42b150f1b7141900 (patch)
treed88ba24311c49774dfead232078b45b917368629 /tensorflow/compiler/xla/service/hlo_computation.h
parent57919740bf151cb6395aa60e30404ee9caa066d6 (diff)
Fix perfromance of HloComputation::ComputeChannelDependencies
Previously it used an std::map containing std::vector's what added a large overhead to HloComputation::MakeInstructionPostOrder when a model contained a large number of channels. The new implementation replaced it with a FlatMap and an InlineVector what eliminates a large number of allocations and improves perfromance by a lot. PiperOrigin-RevId: 210531816
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 8d9b694977..59016624f7 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -403,8 +403,15 @@ class HloComputation {
// instructions. For send&recv pairs it means the send instruction and for
// cross-replica-sum the union of the dependencies for all participating
// instructions.
- std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
- const;
+ using ChannelDependencyMap =
+ tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ ChannelDependencyMap ComputeChannelDependencies() const;
+
+ enum VisitState { kVisiting, kVisited };
+ void ComputeInstructionPostOrder(
+ const HloComputation::ChannelDependencyMap& channel_dependency_map,
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
string name_;
int64 unique_id_;