aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-09 16:40:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 16:42:52 -0700
commitc07b719ab030c46f19c8e5cdd92730eaec38a8fb (patch)
treeac03b54d62c68ce528027e7e88d912618b8dac0e /tensorflow/compiler/xla/service/hlo_computation.cc
parentb348209171a2fac38def122d2ee43bd2fc3d9b1d (diff)
[XLA] Make hlo deserialization stable for HloModule by sorting by ids when creating from proto.
Also, delete the HloModule parameter HloInstruction::CreateFromProto, it's not used anywhere. Also, in ToProto, set sharding to proto if there is sharding. PiperOrigin-RevId: 196049173
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 17e43c3cb8..05dceb1dc0 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -407,27 +407,37 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
- HloModule* module, const HloComputationProto& proto,
+ const HloComputationProto& proto,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
- std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+ std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloInstruction> instruction,
- HloInstruction::CreateFromProto(module, instruction_proto,
- instruction_map, computation_map));
+ HloInstruction::CreateFromProto(instruction_proto, instruction_map,
+ computation_map));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
instruction_map[instruction_proto.id()] = instruction.get();
+ to_proto_id[instruction.get()] = instruction_proto.id();
instructions.push_back(std::move(instruction));
}
TF_RET_CHECK(proto.root_id() != -1);
TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
HloInstruction* root = instruction_map.at(proto.root_id());
+
+ // Sort the instructions in the proto id's order.
+ std::sort(instructions.begin(), instructions.end(),
+ [&](const std::unique_ptr<HloInstruction>& a,
+ const std::unique_ptr<HloInstruction>& b) {
+ return to_proto_id[a.get()] < to_proto_id[b.get()];
+ });
+
return WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));