aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-12 13:19:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 13:23:12 -0700
commit5d1de24583aabeb2cb883ab197ae2b8d5446c565 (patch)
tree8e1227ad724f3da4413ce51ef4d39925b7ff226a
parent3fb474713b27552eba1943bb4172e54ad2dd13bc (diff)
Preserve unique ids when serializing/deserializing HLO protos.
Re-assigning unique IDs broke serialization of HloSchedule, and keeping IDs stable improves the fidelity of the proto serialization. This change requires that instructions in HLO module protos have valid, module-scope-unique ids so change the XLA builder to hand out module-scope-unique ids. Previously, instruction ids were only unique in the computation scope. PiperOrigin-RevId: 212692339
-rw-r--r--tensorflow/compiler/aot/tests/BUILD1
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc23
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc11
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc42
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h7
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc53
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc94
11 files changed, 196 insertions, 47 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 8d94f5495c..7a0932d44d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -231,6 +231,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index dd2b151098..7ac90fb8a9 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) {
string hlo_profile_as_string =
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
/*clock_rate_ghz=*/1.0);
- VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+ VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
+
+ // Strip away identifier details from the profile string to avoid this test
+ // being a change detector for xla internals. Identifiers such as '%dot.0.7'
+ // just become '%dot'.
+ RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
+ VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
absl::StrSplit(hlo_profile_as_string, '\n');
@@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto add_profile_line = HasSubstr(
- "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto tuple_profile_line = HasSubstr(
- "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
- "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
- auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+ "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+ "f32[2,2]{1,0} %add)");
+ auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 100b10cd83..72b17d04fc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
auto instr1 = c1.instructions(j);
auto instr2 = c2.instructions(j);
instr1.clear_name();
+ instr1.clear_id();
+ instr1.clear_operand_ids();
instr2.clear_name();
- // The names of instructions were uniquified by the XlaBuilder, the rest
- // of the fields should be identical.
+ instr2.clear_id();
+ instr2.clear_operand_ids();
+ // The names of instructions were uniquified by the XlaBuilder and the
+ // unique ids may be different, the rest of the fields should be
+ // identical.
string str1, str2;
+ LOG(INFO) << "instr1 = " << instr1.DebugString();
+ LOG(INFO) << "instr2 = " << instr2.DebugString();
instr1.AppendPartialToString(&str1);
instr2.AppendPartialToString(&str2);
EXPECT_EQ(str1, str2);
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 8951e93ee6..95ff6432a5 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
- TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
+ LookUpInstructionByHandle(root_id));
ProgramShape program_shape;
- *program_shape.mutable_result() = instructions_[root_id].shape();
+ *program_shape.mutable_result() = root_proto->shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
return;
}
- CHECK(op_handle < instructions_.size() && op_handle >= 0);
-
- const HloInstructionProto& instr = instructions_[op_handle];
+ const HloInstructionProto& instr =
+ *(LookUpInstructionByHandle(op_handle).ValueOrDie());
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
switch (opcode) {
default:
@@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
// Clear data held by this builder.
this->instructions_.clear();
+ this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
@@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
- // also a valid denpendency order). The related ops will be added to the
+ // also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
@@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
worklist.push(root->id());
related_ops.insert(root->id());
while (!worklist.empty()) {
- int64 node = worklist.front();
+ int64 handle = worklist.front();
worklist.pop();
- for (int64 id : instructions_[node].operand_ids()) {
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+ LookUpInstructionByHandle(handle));
+ for (int64 id : instr_proto->operand_ids()) {
if (related_ops.insert(id).second) {
worklist.push(id);
}
}
- for (int64 called_id : instructions_[node].called_computation_ids()) {
+ for (int64 called_id : instr_proto->called_computation_ids()) {
related_calls.insert(called_id);
}
}
@@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// Add related ops to the computation.
for (int64 id : related_ops) {
auto* instr = entry.add_instructions();
- *instr = instructions_[id];
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
+ LookUpInstructionByHandle(id));
+ *instr = *instr_src;
// Ensures that the instruction names are unique among the graph.
const string& new_name =
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@@ -2415,7 +2420,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
- const int64 handle = instructions_.size();
+ const int64 handle = GetUniqueId();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
@@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
*instr.mutable_sharding() = *sharding_;
}
- instructions_.push_back(instr);
+ handle_to_index_[handle] = instructions_.size();
+ instructions_.push_back(std::move(instr));
XlaOp op(handle, this);
return op;
@@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
op.handle(), op.builder_->name(), this->name());
}
- if (op.handle() >= instructions_.size() || op.handle() < 0) {
- return InvalidArgument("no XlaOp value %d", op.handle());
+ return LookUpInstructionByHandle(op.handle());
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
+ int64 handle) const {
+ auto it = handle_to_index_.find(handle);
+ if (it == handle_to_index_.end()) {
+ return InvalidArgument("No XlaOp with handle %d", handle);
}
- return &instructions_[op.handle()];
+ return &instructions_[it->second];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 833eafcf85..d0c59fa6f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+ int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
+ // A map from XlaOp::Handle to the index in the instructions_ vector where the
+ // instruction is held.
+ tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index d2bea9c8da..fc259a6ca2 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1963,6 +1963,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
+ ":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 233d2199d1..8c6903d766 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ auto computation = absl::WrapUnique(
+ new HloComputation(proto.name(), parameter_count, &instructions, root,
+ /*fusion_instruction=*/nullptr));
+ computation->unique_id_ = proto.id();
+ return std::move(computation);
}
void HloComputation::FuseInstructionsInto(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 85fa3ce964..e905f2983a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+ // Set unique id to this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+ // for computations and instructions created later. Also, set the
+ // next_unique_id_ to the one greater than the max unique id of any
+ // instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+ // Because we didn't uniquify the names or the ids, double-check that the
+ // instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 26fd1b2438..3bc2d13781 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -253,7 +253,7 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names);
+ bool uniquify_identifiers);
const string name_;
HloModuleConfig config_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 400bd4d947..6243943420 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -253,6 +254,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+ // Verify that serializing then deserializing an HLO proto preserves the
+ // unique IDs of the instruction and module.
+ const string text =
+ R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+ input = f32[8,16,256]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+
+ // Perform various transformations on the graph:
+ //
+ // * clone the reduction function
+ // * replace use of reduction function with the clone.
+ // * add a random instruction to the entry computation.
+ //
+ // This will create instruction and computation IDs which are interesting:
+ // not consecutive and not densely packed.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+ HloComputation* reduction = root->to_apply();
+ HloComputation* reduction_clone =
+ module->AddEmbeddedComputation(reduction->Clone());
+ root->set_to_apply(reduction_clone);
+ TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+ HloInstruction* negate = entry->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+ entry->set_root_instruction(negate);
+
+ // Schedule the transformed module, this verifies that the serialized schedule
+ // is robust against non-consecutive IDs as well (b/114712358).
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ HloMemoryScheduler scheduler(size_fn);
+ TF_ASSERT_OK(scheduler.Run(module.get()).status());
+ ASSERT_TRUE(module->has_schedule());
+
+ // Serialize and deserialize and verify that the instruction and computations
+ // unique ids are the same.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+ // The module IDs should *not* be the same because module ids must be globally
+ // unique.
+ EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+ // Verify that the computations and instructions all have the same unique id.
+ auto computation_copy_it = module_copy->computations().begin();
+ for (const HloComputation* computation_orig : module->computations()) {
+ const HloComputation* computation_copy = *computation_copy_it++;
+ EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original computation %s != ID of deserialized "
+ "computation %s: %d != %d",
+ computation_orig->name(), computation_copy->name(),
+ computation_orig->unique_id(), computation_copy->unique_id());
+
+ auto instruction_copy_it = computation_copy->instructions().begin();
+ for (const HloInstruction* instruction_orig :
+ computation_orig->instructions()) {
+ const HloInstruction* instruction_copy = *instruction_copy_it++;
+ EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original instruction %s != ID of deserialized "
+ "instruction %s: %d != %d",
+ instruction_orig->name(), instruction_copy->name(),
+ instruction_orig->unique_id(), instruction_copy->unique_id());
+ }
+ }
+
+ // Verify that the next unique ID which the module would have handed out is
+ // greater than the unique id of any instruction.
+ int next_id = module_copy->NewUniqueInstructionId();
+ for (const HloComputation* computation : module_copy->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_GT(next_id, instruction->unique_id());
+ }
+ }
+}
+
} // namespace
} // namespace xla