aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-20 16:13:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-20 16:17:13 -0700
commit49ee96a60bea1b595cff3cb550cfc8d2ade5ed8b (patch)
treeaa71f2881ff280bb608ee9d0ec9ee05fc7c1a93c /tensorflow/compiler/xla/service/hlo_computation.cc
parent0bd851b38810540034069d92a2f76a026429bced (diff)
[XLA] Use IDs instead of names to represent the edges of HLO graph in hlo.proto.
PiperOrigin-RevId: 189831057
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc27
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 4e852190a8..6f983d0b95 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -65,6 +65,7 @@ HloComputation::HloComputation(
std::vector<std::unique_ptr<HloInstruction>>* instructions,
HloInstruction* root_instruction, HloInstruction* fusion_instruction)
: name_(name),
+ unique_id_(-1),
root_instruction_(root_instruction),
fusion_instruction_(fusion_instruction) {
param_instructions_.resize(parameter_count, nullptr);
@@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
instruction->UniquifyName(&parent()->instruction_name_uniquer());
instruction->SetUniqueId(parent()->NewUniqueInstructionId());
}
- Reparent(instruction.get());
+ instruction->set_parent(this);
HloInstruction* pinst = instruction.get();
instruction_iterators_[pinst] =
instructions_.insert(instructions_.end(), std::move(instruction));
@@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
return Status::OK();
}
-void HloComputation::Reparent(HloInstruction* instruction) {
- instruction->set_parent(this);
-}
-
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for
@@ -393,12 +390,16 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
HloComputationProto HloComputation::ToProto() const {
HloComputationProto proto;
+ CHECK(unique_id_ != -1)
+ << "This computation does not have a valid id. Please make sure the "
+ "computation is inside a module before dumping it.";
+ proto.set_id(unique_id_);
proto.set_name(name_);
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
HloInstructionProto instruction_proto = instruction->ToProto();
proto.add_instructions()->Swap(&instruction_proto);
}
- proto.set_root_name(root_instruction()->name());
+ proto.set_root_id(root_instruction()->unique_id());
*proto.mutable_program_shape() = ComputeProgramShape();
return proto;
}
@@ -406,9 +407,9 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
+ const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
std::vector<std::unique_ptr<HloInstruction>> instructions;
- tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
+ tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
@@ -418,14 +419,14 @@ HloComputation::CreateFromProto(
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
- TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name()));
- instruction_map[instruction->name()] = instruction.get();
+ TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
+ instruction_map[instruction_proto.id()] = instruction.get();
instructions.push_back(std::move(instruction));
}
- TF_RET_CHECK(!proto.root_name().empty());
- TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
- HloInstruction* root = instruction_map.at(proto.root_name());
+ TF_RET_CHECK(proto.root_id() != -1);
+ TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
+ HloInstruction* root = instruction_map.at(proto.root_id());
return WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));