aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-17 13:39:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-17 14:59:42 -0700
commit33fd4134234170745f989e2cdd73c8ca8709d926 (patch)
treec7fad80e48ba8cb8da246583669266c143a75543 /tensorflow/compiler/xla/service/hlo_computation.cc
parentcc45456e4ad0eff16127d1727d0cf48afb71ca0e (diff)
[XLA] Represent fusion instructions as a HloComputation
Using a HloComputation to represent the HloInstructions inside a fusion instruction. All the interfaces are kept the same except for the parent field of the fusion instruction. It now points to the newly created HloComputation rather the enclosing computation for the fusion instruction. Change: 153390245
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc73
1 files changed, 52 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 15de24fffd..655546d715 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -52,16 +52,17 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
- return WrapUnique(
- new HloComputation(name_, parameter_count, &instructions_, root));
+ return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
+ root, is_fusion_computation_));
}
HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction)
+ HloInstruction* root_instruction, bool is_fusion_computation)
: name_(name),
root_instruction_(root_instruction),
+ is_fusion_computation_(is_fusion_computation),
instruction_name_uniquer_(/*separator=*/".") {
param_instructions_.resize(parameter_count, nullptr);
bool root_found = false;
@@ -99,19 +100,54 @@ HloInstruction* HloComputation::AddInstructionInternal(
return pinst;
}
-void HloComputation::Reparent(HloInstruction* instruction) {
+HloInstruction* HloComputation::AddParameter(
+ std::unique_ptr<HloInstruction> instruction) {
+ CHECK(instruction->opcode() == HloOpcode::kParameter);
+ CHECK(is_fusion_computation_);
+ CHECK(root_instruction_->fusion_instruction() != nullptr);
+ instruction->SetParentFusion(root_instruction_->fusion_instruction());
+ CHECK(root_instruction_->fusion_instruction()->operand_count() ==
+ param_instructions_.size());
instruction->set_parent(this);
- if (instruction->opcode() == HloOpcode::kFusion) {
- for (auto& i : instruction->fused_instructions()) {
- Reparent(i.get());
- }
+ param_instructions_.push_back(instruction.get());
+ AddInstructionInternal(std::move(instruction));
+ return instructions_.back().get();
+}
+
+Status HloComputation::RemoveParameter(int64 param_no) {
+ CHECK_GE(param_no, 0);
+ CHECK_LT(param_no, param_instructions_.size());
+ CHECK(is_fusion_computation_);
+ CHECK(root_instruction_->fusion_instruction() != nullptr);
+ HloInstruction* param_instruction = param_instructions_[param_no];
+ auto param_instruction_iterator = param_instructions_.begin() + param_no;
+ param_instructions_.erase(param_instruction_iterator);
+ // Throw removed fused parameter instruction away.
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+
+ while (param_no < param_instructions_.size()) {
+ param_instruction = param_instructions_[param_no];
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_instruction->name()));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ new_instr->SetParentFusion(root_instruction_->fusion_instruction());
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ param_no++;
}
+
+ return Status::OK();
}
-/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
- return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
- opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||
- opcode == HloOpcode::kOutfeed);
+void HloComputation::Reparent(HloInstruction* instruction) {
+ instruction->set_parent(this);
+}
+
+bool HloComputation::IsRemovable(const HloOpcode& opcode) {
+ return !((opcode == HloOpcode::kParameter && !is_fusion_computation_) ||
+ opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend ||
+ opcode == HloOpcode::kTrace || opcode == HloOpcode::kOutfeed);
}
Status HloComputation::RemoveInstructionAndUnusedOperands(
@@ -119,7 +155,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
TF_RET_CHECK(root_instruction() != instruction);
TF_RET_CHECK(instruction->user_count() == 0);
- TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode()));
+ TF_RET_CHECK(IsRemovable(instruction->opcode()));
std::unordered_set<HloInstruction*> removed;
std::queue<HloInstruction*> worklist;
worklist.push(instruction);
@@ -128,8 +164,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
worklist.pop();
if (removed.count(item) != 0 || item->user_count() != 0 ||
- item == root_instruction() ||
- !HloComputation::IsRemovable(item->opcode())) {
+ item == root_instruction() || !IsRemovable(item->opcode())) {
continue;
}
for (int i = 0; i < item->operand_count(); ++i) {
@@ -302,12 +337,8 @@ string HloComputation::ToString() const {
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
s << " " << instruction->ToString() << "\n";
if (instruction->opcode() == HloOpcode::kFusion) {
- tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
- auto fused_instructions = InstructionPostOrderer::GetOrder(
- instruction->fused_expression_root(), &added_instructions);
- for (const auto& fused_instruction : fused_instructions) {
- s << " " << fused_instruction->ToString() << "\n";
- }
+ s << " " << instruction->fused_instructions_computation()->ToString()
+ << "\n";
}
}
s << "}";