diff options
author | 2017-10-13 16:24:56 -0700 | |
---|---|---|
committer | 2017-10-13 16:29:20 -0700 | |
commit | 5dd569cf026bae92330a194c8f2895d0f48149d9 (patch) | |
tree | 96dbce8d2992fb1f14aa0a1265904eb57eaf2273 /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | d426d3029727785676d1a7fbb7973a3a6ceb4842 (diff) |
Make the HLO proto representation (hlo.proto) full fidelity. Hlo modules can be serialized to HLO protos and deserialized without any information loss.
As part of this change, a bug is fixed in NameUniquer. Previously, passing names with numeric suffixes could result in name collisions.
PiperOrigin-RevId: 172161360
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 444104d88f..9b3104eaac 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -400,9 +400,38 @@ HloComputationProto HloComputation::ToProto() const { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } + proto.set_root_name(root_instruction()->name()); return proto; } +/* static */ StatusOr<std::unique_ptr<HloComputation>> +HloComputation::CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map, + HloInstruction* fusion_instruction) { + std::vector<std::unique_ptr<HloInstruction>> instructions; + tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map; + int64 parameter_count = 0; + for (const HloInstructionProto& instruction_proto : proto.instructions()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<HloInstruction> instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_count++; + } + TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); + instruction_map[instruction->name()] = instruction.get(); + instructions.push_back(std::move(instruction)); + } + + TF_RET_CHECK(!proto.root_name().empty()); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); + HloInstruction* root = instruction_map.at(proto.root_name()); + return WrapUnique(new HloComputation( + proto.name(), parameter_count, &instructions, root, fusion_instruction)); +} + void HloComputation::FuseInstructionsInto( tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse, HloInstruction* fusion_instruction) { |