author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 09:30:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 09:40:01 -0700 |
commit | 87d8055c74a65ec9fb2a13f38e6e2c5d30b7e2e4 (patch) | |
tree | 41d1f00e201c7f108b04100c4dd08203ba01c37e /tensorflow/compiler/xla/service | |
parent | 92d533d19c44ab838a1f7954350fdafd62cfa889 (diff) |
Correctly pre-reserve visit state in HloInstruction::PostOrderDFS
Previously we pre-reserved the visit state based on the number of
instructions but then indexed it with the instruction unique ID,
which can be larger than the instruction count. This resulted in some
very expensive re-allocations, which can be eliminated by reserving a
correctly sized buffer.
PiperOrigin-RevId: 216369849
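The mismatch described in the message can be illustrated with a minimal, self-contained sketch. It is not the XLA implementation: `VisitStateTable`, `Set`, and `CountReallocations` are hypothetical names, and the sketch assumes the visit state lives in a `std::vector` indexed by unique ID that grows on demand when an ID falls outside its current size.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a visitor's per-instruction state table.
// Assumption: state is stored in a vector indexed by instruction unique ID.
enum class VisitState { kNotVisited, kVisiting, kVisited };

class VisitStateTable {
 public:
  // Pre-allocates room for `n` entries without changing the size.
  void Reserve(std::size_t n) { states_.reserve(n); }

  // Records the state for `id`, growing the table if `id` is out of range.
  // Returns true if the underlying buffer capacity changed (a reallocation).
  bool Set(std::size_t id, VisitState state) {
    const std::size_t old_capacity = states_.capacity();
    if (id >= states_.size()) {
      states_.resize(id + 1, VisitState::kNotVisited);
    }
    states_[id] = state;
    return states_.capacity() != old_capacity;
  }

 private:
  std::vector<VisitState> states_;
};

// Marks every ID in `ids` as visited after reserving `reserved` slots and
// counts how many of those writes forced a reallocation.
std::size_t CountReallocations(std::size_t reserved,
                               const std::vector<std::size_t>& ids) {
  VisitStateTable table;
  table.Reserve(reserved);
  std::size_t reallocations = 0;
  for (std::size_t id : ids) {
    if (table.Set(id, VisitState::kVisited)) ++reallocations;
  }
  return reallocations;
}
```

With three live instructions whose IDs are 10, 500, and 999, reserving by count (3) leaves the buffer too small for the first write, while reserving by the ID upper bound (1000) avoids any growth during the traversal.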
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5c3908a9a4..050d28b289 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2474,7 +2474,7 @@ template <typename Visitor>
 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                            const InternalCompareFunction* operand_order,
                            bool ignore_control_predecessors) {
-  visitor->ReserveVisitStates(root->GetModule()->instruction_count());
+  visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
 
   // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
   //