diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-09-26 09:10:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 09:18:25 -0700 |
commit | 01512356e10ab87887e3c7b69f9ed3e5a8397f76 (patch) | |
tree | 80f71f8f279a1d65960f66ed516e8d0f26751338 /tensorflow/compiler | |
parent | e9f76594ca1d7ea5317e0535d4a4bfffb269a1f9 (diff) |
[XLA] Don't use NumUniqueInstructionIds() as a proxy for instruction_count()
It used to be a reasonable proxy, but that's no longer the case. This is because GetUniqueId() in XlaBuilder uses a *global* (rather than a module-global) counter. Since HloModule::CreateFromProto no longer uniquifies ids coming in from protos, the potentially very high IDs coming from GetUniqueId() become the module's next_unique_id.
There is another case of this in TuplePointsTo, which will be handled separately.
PiperOrigin-RevId: 214614576
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_memory_scheduler.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/logical_buffer_analysis.cc | 2 |
3 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ad58833e4d..f7ec854d80 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2423,7 +2423,7 @@ template <typename Visitor> static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { - visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); + visitor->ReserveVisitStates(root->GetModule()->instruction_count()); // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>. // diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index c7ec88d450..6a4e766788 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -400,7 +400,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler( memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); + int64 total_hlos = computation.parent()->instruction_count(); tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users; tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index eaa09591b7..ec52a24d78 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() { // so reserve 10% more than the number of instructions to avoid frequent // resizes. 
logical_buffers_.clear(); - logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10); + logical_buffers_.reserve((module_->instruction_count() * 11) / 10); // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) |