aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc25
1 files changed, 8 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c030ceb72f..2d07784619 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -58,16 +58,16 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
CHECK_NE(nullptr, root);
return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
- root, is_fusion_computation_));
+ root, fusion_instruction_));
}
HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation)
+ HloInstruction* root_instruction, HloInstruction* fusion_instruction)
: name_(name),
root_instruction_(root_instruction),
- is_fusion_computation_(is_fusion_computation) {
+ fusion_instruction_(fusion_instruction) {
param_instructions_.resize(parameter_count, nullptr);
bool root_found = false;
for (auto& instruction : *instructions) {
@@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal(
HloInstruction* HloComputation::AddParameter(
std::unique_ptr<HloInstruction> instruction) {
CHECK(instruction->opcode() == HloOpcode::kParameter);
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- CHECK(root_instruction_->fusion_instruction()->operand_count() ==
- param_instructions_.size());
+ CHECK(IsFusionComputation());
+ CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
instruction->set_parent(this);
param_instructions_.push_back(instruction.get());
AddInstructionInternal(std::move(instruction));
@@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter(
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
- CHECK(is_fusion_computation_);
- CHECK(root_instruction_->fusion_instruction() != nullptr);
+ CHECK(IsFusionComputation());
HloInstruction* param_instruction = param_instructions_[param_no];
auto param_instruction_iterator = param_instructions_.begin() + param_no;
param_instructions_.erase(param_instruction_iterator);
@@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
- new_instr->SetParentFusion(root_instruction_->fusion_instruction());
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
param_no++;
@@ -166,10 +161,6 @@ Status HloComputation::RemoveParameter(int64 param_no) {
void HloComputation::Reparent(HloInstruction* instruction) {
instruction->set_parent(this);
- if (is_fusion_computation_ && instruction != root_instruction_) {
- CHECK(root_instruction_->fusion_instruction() != nullptr);
- instruction->SetParentFusion(root_instruction_->fusion_instruction());
- }
}
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
@@ -182,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) {
}
if (instruction->opcode() == HloOpcode::kParameter &&
- !is_fusion_computation_) {
+ !IsFusionComputation()) {
return false;
}
@@ -267,7 +258,7 @@ void HloComputation::set_root_instruction(
HloInstruction* new_root_instruction) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!is_fusion_computation_) {
+ if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape().ShortDebugString()