diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 53 |
1 files changed, 36 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cfe906d9c5..b3949f3a6d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) { HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. tensorflow::gtl::FlatSet<string> computation_names; tensorflow::gtl::FlatSet<string> instruction_names; + tensorflow::gtl::FlatSet<int> computation_ids; + tensorflow::gtl::FlatSet<int> instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } |