aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc53
1 files changed, 36 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+ // Set unique id to this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+ // for computations and instructions created later. Also, set the
+ // next_unique_id_ to the one greater than the max unique id of any
+ // instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+ // Because we didn't uniquify the names or the ids, double-check that the
+ // instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}