aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-10-13 16:24:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 16:29:20 -0700
commit5dd569cf026bae92330a194c8f2895d0f48149d9 (patch)
tree96dbce8d2992fb1f14aa0a1265904eb57eaf2273 /tensorflow/compiler/xla/service/hlo_computation.cc
parentd426d3029727785676d1a7fbb7973a3a6ceb4842 (diff)
Make the HLO proto representation (hlo.proto) full fidelity. Hlo modules can be serialized to HLO protos and deserialized without any information loss.
As part of this change, a bug is fixed in NameUniquer. Previously, passing names with numeric suffixes could result in name collisions. PiperOrigin-RevId: 172161360
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 444104d88f..9b3104eaac 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -400,9 +400,38 @@ HloComputationProto HloComputation::ToProto() const {
HloInstructionProto instruction_proto = instruction->ToProto();
proto.add_instructions()->Swap(&instruction_proto);
}
+ proto.set_root_name(root_instruction()->name());
return proto;
}
+/* static */ StatusOr<std::unique_ptr<HloComputation>>
+HloComputation::CreateFromProto(
+ HloModule* module, const HloComputationProto& proto,
+ tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
+ HloInstruction* fusion_instruction) {
+ std::vector<std::unique_ptr<HloInstruction>> instructions;
+ tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
+ int64 parameter_count = 0;
+ for (const HloInstructionProto& instruction_proto : proto.instructions()) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloInstruction> instruction,
+ HloInstruction::CreateFromProto(module, instruction_proto,
+ instruction_map, computation_map));
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ parameter_count++;
+ }
+ TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name()));
+ instruction_map[instruction->name()] = instruction.get();
+ instructions.push_back(std::move(instruction));
+ }
+
+ TF_RET_CHECK(!proto.root_name().empty());
+ TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
+ HloInstruction* root = instruction_map.at(proto.root_name());
+ return WrapUnique(new HloComputation(
+ proto.name(), parameter_count, &instructions, root, fusion_instruction));
+}
+
void HloComputation::FuseInstructionsInto(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
HloInstruction* fusion_instruction) {