diff options
author | Mark Heffernan <meheff@google.com> | 2018-09-12 13:19:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 13:23:12 -0700 |
commit | 5d1de24583aabeb2cb883ab197ae2b8d5446c565 (patch) | |
tree | 8e1227ad724f3da4413ce51ef4d39925b7ff226a | |
parent | 3fb474713b27552eba1943bb4172e54ad2dd13bc (diff) |
Preserve unique ids when serializing/deserializing HLO protos.
Re-assigning unique IDs broke serialization of HloSchedule, and keeping IDs stable improves the fidelity of the proto serialization. This change requires that instructions in HLO module protos have valid, module-scope-unique ids, so change the XLA builder to hand out module-scope-unique ids. Previously, instruction ids were only unique in the computation scope.
PiperOrigin-RevId: 212692339
-rw-r--r-- | tensorflow/compiler/aot/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tests/tfcompile_test.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 42 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 53 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_test.cc | 94 |
11 files changed, 196 insertions, 47 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 8d94f5495c..7a0932d44d 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -231,6 +231,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index dd2b151098..7ac90fb8a9 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) { string hlo_profile_as_string = xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), /*clock_rate_ghz=*/1.0); - VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; + + // Strip away identifier details from the profile string to avoid this test + // being a change detector for xla internals. Identifiers such as '%dot.0.7' + // just become '%dot'. 
+ RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); + VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; std::vector<string> hlo_profile_lines = absl::StrSplit(hlo_profile_as_string, '\n'); @@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto tuple_profile_line = HasSubstr( - "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " + "f32[2,2]{1,0} %add)"); + auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 100b10cd83..72b17d04fc 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { auto instr1 = c1.instructions(j); auto instr2 = c2.instructions(j); instr1.clear_name(); + instr1.clear_id(); + instr1.clear_operand_ids(); instr2.clear_name(); - // The names of instructions were uniquified by the 
XlaBuilder, the rest - // of the fields should be identical. + instr2.clear_id(); + instr2.clear_operand_ids(); + // The names of instructions were uniquified by the XlaBuilder and the + // unique ids may be different, the rest of the fields should be + // identical. string str1, str2; + LOG(INFO) << "instr1 = " << instr1.DebugString(); + LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); instr2.AppendPartialToString(&str2); EXPECT_EQ(str1, str2); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 8951e93ee6..95ff6432a5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn( StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, + LookUpInstructionByHandle(root_id)); ProgramShape program_shape; - *program_shape.mutable_result() = instructions_[root_id].shape(); + *program_shape.mutable_result() = root_proto->shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, return; } - CHECK(op_handle < instructions_.size() && op_handle >= 0); - - const HloInstructionProto& instr = instructions_[op_handle]; + const HloInstructionProto& instr = + *(LookUpInstructionByHandle(op_handle).ValueOrDie()); const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); switch (opcode) { default: @@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) { // Clear data held by this builder. 
this->instructions_.clear(); + this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is - // also a valid denpendency order). The related ops will be added to the + // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set<int64> related_ops; tensorflow::gtl::FlatSet<int64> related_calls; // Related computations. @@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( worklist.push(root->id()); related_ops.insert(root->id()); while (!worklist.empty()) { - int64 node = worklist.front(); + int64 handle = worklist.front(); worklist.pop(); - for (int64 id : instructions_[node].operand_ids()) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(handle)); + for (int64 id : instr_proto->operand_ids()) { if (related_ops.insert(id).second) { worklist.push(id); } } - for (int64 called_id : instructions_[node].called_computation_ids()) { + for (int64 called_id : instr_proto->called_computation_ids()) { related_calls.insert(called_id); } } @@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( // Add related ops to the computation. for (int64 id : related_ops) { auto* instr = entry.add_instructions(); - *instr = instructions_[id]; + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, + LookUpInstructionByHandle(id)); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. 
const string& new_name = StrCat(instr->name(), ".", entry.id(), ".", instr->id()); @@ -2415,7 +2420,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr, absl::Span<const XlaOp> operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = instructions_.size(); + const int64 handle = GetUniqueId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { @@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr, *instr.mutable_sharding() = *sharding_; } - instructions_.push_back(instr); + handle_to_index_[handle] = instructions_.size(); + instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op; @@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction( op.handle(), op.builder_->name(), this->name()); } - if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %d", op.handle()); + return LookUpInstructionByHandle(op.handle()); +} + +StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle( + int64 handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions_[op.handle()]; + return &instructions_[it->second]; } // Enqueues a "retrieve parameter value" instruction for a parameter that was diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 833eafcf85..d0c59fa6f2 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -34,6 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -955,6 +956,8 @@ class XlaBuilder { HloInstructionProto* instr); StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const; + StatusOr<const HloInstructionProto*> LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -1024,6 +1027,10 @@ class XlaBuilder { // The instructions of this computation. std::vector<HloInstructionProto> instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + tensorflow::gtl::FlatMap<int64, int64> handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. 
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d2bea9c8da..fc259a6ca2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1963,6 +1963,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", + ":hlo_memory_scheduler", ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 233d2199d1..8c6903d766 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -562,9 +562,11 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 85fa3ce964..e905f2983a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cfe906d9c5..b3949f3a6d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ 
b/tensorflow/compiler/xla/service/hlo_module.cc @@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) { HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. 
- for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr<HloComputation> computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. 
module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. tensorflow::gtl::FlatSet<string> computation_names; tensorflow::gtl::FlatSet<string> instruction_names; + tensorflow::gtl::FlatSet<int> computation_ids; + tensorflow::gtl::FlatSet<int> instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 26fd1b2438..3bc2d13781 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -253,7 +253,7 @@ class HloModule { private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); const string name_; 
HloModuleConfig config_; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 400bd4d947..6243943420 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -253,6 +254,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instruction and module. + const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. 
+ HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. 
+ auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. + int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla |