diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-09 16:40:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-09 16:42:52 -0700 |
commit | c07b719ab030c46f19c8e5cdd92730eaec38a8fb (patch) | |
tree | ac03b54d62c68ce528027e7e88d912618b8dac0e /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | b348209171a2fac38def122d2ee43bd2fc3d9b1d (diff) |
[XLA] Make hlo deserialization stable for HloModule by sorting by ids when creating from proto.
Also, delete the HloModule parameter HloInstruction::CreateFromProto, it's not used anywhere.
Also, in ToProto, set sharding to proto if there is sharding.
PiperOrigin-RevId: 196049173
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 17e43c3cb8..05dceb1dc0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -407,27 +407,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloComputation>> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) { - std::vector<std::unique_ptr<HloInstruction>> instructions; tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map; + tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id; + std::vector<std::unique_ptr<HloInstruction>> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr<HloInstruction> instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. + std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr<HloInstruction>& a, + const std::unique_ptr<HloInstruction>& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); |