author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-04-17 13:39:36 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-04-17 14:59:42 -0700
commit    33fd4134234170745f989e2cdd73c8ca8709d926 (patch)
tree      c7fad80e48ba8cb8da246583669266c143a75543
parent    cc45456e4ad0eff16127d1727d0cf48afb71ca0e (diff)
[XLA] Represent fusion instructions as a HloComputation
Use an HloComputation to represent the HloInstructions inside a fusion instruction. All the interfaces are kept the same except for the parent field of the fused instructions: it now points to the newly created HloComputation rather than to the enclosing computation of the fusion instruction. Change: 153390245
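For orientation, the representational change can be sketched as follows (a simplified view reconstructed from the diff below, not the full class definitions):

// Before this change, a kFusion HloInstruction tracked its fused
// expression with three ad-hoc fields:
//   std::list<std::unique_ptr<HloInstruction>> fused_instructions_;
//   std::vector<HloInstruction*> fused_parameters_;
//   HloInstruction* fused_root_ = nullptr;
// After this change, a single owned computation replaces all three:
//   std::unique_ptr<HloComputation> fused_instructions_computation_;
// The old accessors survive by delegation:
//   fused_instructions()    -> computation->instructions()
//   fused_parameters()      -> computation->parameter_instructions()
//   fused_expression_root() -> computation->root_instruction()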
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc | 73
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h | 34
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dce.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 126
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 32
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.cc | 3
7 files changed, 192 insertions(+), 83 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 15de24fffd..655546d715 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -52,16 +52,17 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build(
root_instruction ? root_instruction : last_added_instruction_;
CHECK_NE(nullptr, root);
- return WrapUnique(
- new HloComputation(name_, parameter_count, &instructions_, root));
+ return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
+ root, is_fusion_computation_));
}
HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction)
+ HloInstruction* root_instruction, bool is_fusion_computation)
: name_(name),
root_instruction_(root_instruction),
+ is_fusion_computation_(is_fusion_computation),
instruction_name_uniquer_(/*separator=*/".") {
param_instructions_.resize(parameter_count, nullptr);
bool root_found = false;
@@ -99,19 +100,54 @@ HloInstruction* HloComputation::AddInstructionInternal(
return pinst;
}
-void HloComputation::Reparent(HloInstruction* instruction) {
+HloInstruction* HloComputation::AddParameter(
+ std::unique_ptr<HloInstruction> instruction) {
+ CHECK(instruction->opcode() == HloOpcode::kParameter);
+ CHECK(is_fusion_computation_);
+ CHECK(root_instruction_->fusion_instruction() != nullptr);
+ instruction->SetParentFusion(root_instruction_->fusion_instruction());
+ CHECK(root_instruction_->fusion_instruction()->operand_count() ==
+ param_instructions_.size());
instruction->set_parent(this);
- if (instruction->opcode() == HloOpcode::kFusion) {
- for (auto& i : instruction->fused_instructions()) {
- Reparent(i.get());
- }
+ param_instructions_.push_back(instruction.get());
+ AddInstructionInternal(std::move(instruction));
+ return instructions_.back().get();
+}
+
+Status HloComputation::RemoveParameter(int64 param_no) {
+ CHECK_GE(param_no, 0);
+ CHECK_LT(param_no, param_instructions_.size());
+ CHECK(is_fusion_computation_);
+ CHECK(root_instruction_->fusion_instruction() != nullptr);
+ HloInstruction* param_instruction = param_instructions_[param_no];
+ auto param_instruction_iterator = param_instructions_.begin() + param_no;
+ param_instructions_.erase(param_instruction_iterator);
+ // Discard the removed fused parameter instruction.
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+
+ while (param_no < param_instructions_.size()) {
+ param_instruction = param_instructions_[param_no];
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_instruction->name()));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ new_instr->SetParentFusion(root_instruction_->fusion_instruction());
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ param_no++;
}
+
+ return Status::OK();
}
-/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
- return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
- opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||
- opcode == HloOpcode::kOutfeed);
+void HloComputation::Reparent(HloInstruction* instruction) {
+ instruction->set_parent(this);
+}
+
+bool HloComputation::IsRemovable(const HloOpcode& opcode) {
+ return !((opcode == HloOpcode::kParameter && !is_fusion_computation_) ||
+ opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend ||
+ opcode == HloOpcode::kTrace || opcode == HloOpcode::kOutfeed);
}
Status HloComputation::RemoveInstructionAndUnusedOperands(
@@ -119,7 +155,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
TF_RET_CHECK(root_instruction() != instruction);
TF_RET_CHECK(instruction->user_count() == 0);
- TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode()));
+ TF_RET_CHECK(IsRemovable(instruction->opcode()));
std::unordered_set<HloInstruction*> removed;
std::queue<HloInstruction*> worklist;
worklist.push(instruction);
@@ -128,8 +164,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
worklist.pop();
if (removed.count(item) != 0 || item->user_count() != 0 ||
- item == root_instruction() ||
- !HloComputation::IsRemovable(item->opcode())) {
+ item == root_instruction() || !IsRemovable(item->opcode())) {
continue;
}
for (int i = 0; i < item->operand_count(); ++i) {
@@ -302,12 +337,8 @@ string HloComputation::ToString() const {
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
s << " " << instruction->ToString() << "\n";
if (instruction->opcode() == HloOpcode::kFusion) {
- tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
- auto fused_instructions = InstructionPostOrderer::GetOrder(
- instruction->fused_expression_root(), &added_instructions);
- for (const auto& fused_instruction : fused_instructions) {
- s << " " << fused_instruction->ToString() << "\n";
- }
+ s << " " << instruction->fused_instructions_computation()->ToString()
+ << "\n";
}
}
s << "}";
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 82934595e1..1b32b9fe56 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -54,8 +54,10 @@ class HloComputation {
// Builder class for HloComputation.
class Builder {
public:
- explicit Builder(const string& name)
- : name_(name), last_added_instruction_(nullptr) {}
+ explicit Builder(const string& name, bool is_fusion_computation = false)
+ : name_(name),
+ last_added_instruction_(nullptr),
+ is_fusion_computation_(is_fusion_computation) {}
// Build and return an HloComputation. The parameter root_instruction
// specifies the already-added instruction to use as the root. If
@@ -74,6 +76,7 @@ class HloComputation {
private:
const string name_;
HloInstruction* last_added_instruction_;
+ bool is_fusion_computation_;
std::vector<std::unique_ptr<HloInstruction>> instructions_;
};
@@ -81,6 +84,16 @@ class HloComputation {
// the instruction.
HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
+ // Remove the param_no'th parameter from the computation.
+ // Note this is only applicable to the computation of a fusion
+ // instruction.
+ Status RemoveParameter(int64 param_no);
+
+ // Adds a new parameter instruction to the computation.
+ // The instruction must be a fresh parameter; it is appended to the
+ // parameter list and inserted into the instruction list.
+ HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
+
// Remove an instruction from the computation. The instruction must have no
// users. Instruction is deallocated with this call.
Status RemoveInstruction(HloInstruction* instruction);
@@ -226,14 +239,18 @@ class HloComputation {
// Returns true if instructions of the given opcode can be removed from the
// computation. Instructions such as parameters and send/receive instructions
// cannot be removed without violating invariants of the HLO computation or
- // module.
- static bool IsRemovable(const HloOpcode& opcode);
+ // module, with the exception of fusion computations: a parameter
+ // instruction is removable from a fusion computation.
+ bool IsRemovable(const HloOpcode& opcode);
+
+ // Returns whether this computation is a fusion computation.
+ bool IsFusionComputation() const { return is_fusion_computation_; }
private:
explicit HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction);
+ HloInstruction* root_instruction, bool is_fusion_computation = false);
// Internal helper for adding instructions.
HloInstruction* AddInstructionInternal(
@@ -241,10 +258,6 @@ class HloComputation {
// Helper for setting the parent of instructions that are added to this
// computation.
- //
- // Because we clone HLO instructions without knowing what computation they're
- // destined to be added to, this is required to appropriate set the parent on
- // fused instruction sequences.
void Reparent(HloInstruction* instruction);
// Fuses HLOs in instructions_to_fuse into fusion_instruction.
@@ -264,6 +277,9 @@ class HloComputation {
string name_;
HloInstruction* root_instruction_;
+ // Whether this is a fusion computation.
+ bool is_fusion_computation_;
+
// Module containing this computation.
HloModule* parent_ = nullptr;
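A hedged usage sketch of the new fusion-computation surface declared above; the identifiers `to_fuse`, `fusion`, and `operand_shape` are assumptions, and in this commit the sequence is actually driven from HloInstruction::CloneAndFuseInternal rather than called directly:

// Build a computation flagged as a fusion body.
HloComputation::Builder builder("multiply.fusion",
                                /*is_fusion_computation=*/true);
builder.AddInstruction(to_fuse->Clone());  // becomes the root
std::unique_ptr<HloComputation> body = builder.Build();

// AddParameter/RemoveParameter CHECK that the root already points at
// its enclosing kFusion instruction, so wire that up first.
body->root_instruction()->SetParentFusion(fusion);
HloInstruction* param = body->AddParameter(HloInstruction::CreateParameter(
    /*parameter_number=*/0, operand_shape, "fusion_param"));
TF_CHECK_OK(body->RemoveParameter(0));  // renumbers later parameters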
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index fdfbbf8baf..dbf5bf28c8 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -52,7 +52,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
for (auto& instruction : computation->instructions()) {
if (instruction->user_count() == 0 &&
live_instructions.count(instruction.get()) == 0 &&
- HloComputation::IsRemovable(instruction->opcode())) {
+ computation->IsRemovable(instruction->opcode())) {
dead_roots.push_back(instruction.get());
}
}
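The one-line change above follows from IsRemovable losing its static qualifier: whether a kParameter is removable now depends on which computation is asked. A small hedged sketch:

// Parameters define a computation's signature and are normally
// pinned, but a fusion computation owns its parameters outright:
//   entry_computation->IsRemovable(HloOpcode::kParameter);   // false
//   fusion_computation->IsRemovable(HloOpcode::kParameter);  // true
// Send/recv, trace, and outfeed remain non-removable either way.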
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 979afa4369..d15b8236bb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -28,6 +28,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -498,14 +500,28 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable());
- bool new_fusion_instruction = fused_instructions_.empty();
- fused_instructions_.emplace_back(instruction_to_fuse->Clone());
- HloInstruction* clone = fused_instructions_.back().get();
- clone->parent_fusion_instruction_ = this;
-
- if (new_fusion_instruction) {
- fused_root_ = clone;
+ HloInstruction* clone = nullptr;
+ if (fused_instructions_computation_ == nullptr) {
+ // New fusion instruction.
+ string computation_name;
+ HloModule* module = GetModule();
+ if (module) {
+ computation_name = module->GetUniqueCompuationName(
+ instruction_to_fuse->name() + ".fusion");
+ } else {
+ computation_name = instruction_to_fuse->name() + ".fusion";
+ }
+ auto builder = HloComputation::Builder(computation_name, true);
+ builder.AddInstruction(instruction_to_fuse->Clone());
+ fused_instructions_computation_ = builder.Build();
+ clone = fused_expression_root();
+ clone->parent_fusion_instruction_ = this;
} else {
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ clone = fused_instructions_computation_->AddInstruction(
+ instruction_to_fuse->Clone());
+ clone->parent_fusion_instruction_ = this;
// instruction_to_fuse is necessarily an operand of the fusion instruction.
// After fusion this will no longer be the case. Remove the operand from the
// operand list and remove its corresponding fused parameter
@@ -513,6 +529,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// consistent with their index in the fused_parameter_ vector.
CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) !=
operands_.end());
+ const std::vector<HloInstruction*>& fused_parameters_ =
+ fused_instructions_computation_->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
if (instruction_to_fuse == operands_[operand_num]) {
// replace the fused parameter instruction's uses with the clone.
@@ -521,22 +539,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// Remove the corresponding fused parameter and operand from their
// respective vectors.
- fused_parameters_.erase(fused_parameters_.begin() + operand_num);
+ TF_CHECK_OK(
+ fused_instructions_computation_->RemoveParameter(operand_num));
operands_.erase(operands_.begin() + operand_num);
-
- // Renumber fused parameter numbers to match the vector index.
- while (operand_num < fused_parameters_.size()) {
- fused_parameters_[operand_num]->parameter_number_ = operand_num;
- operand_num++;
- }
- // Throw removed fused parameter instruction away.
- auto inst_it =
- std::find_if(fused_instructions_.begin(), fused_instructions_.end(),
- [=](const std::unique_ptr<HloInstruction>& inst) {
- return inst.get() == fused_parameter;
- });
- CHECK(inst_it != fused_instructions_.end());
- fused_instructions_.erase(inst_it);
break;
}
}
@@ -545,6 +550,10 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
instruction_to_fuse->RemoveUser(this);
}
+ // Reread the parameters in the computation.
+ const std::vector<HloInstruction*>& fused_parameters_ =
+ fused_instructions_computation_->parameter_instructions();
+
// Add each operand of the clone as an operand of the fusion instruction. A
// complication is that some clone operands may already be operands of the
// fusion instruction.
@@ -570,13 +579,10 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
std::unique_ptr<HloInstruction> param_instruction =
CreateParameter(param_no, operand->shape(), "fusion_param");
- param_instruction->set_parent(parent());
param_instruction->parent_fusion_instruction_ = this;
- fused_parameters_.push_back(param_instruction.get());
- fused_instructions_.push_back(std::move(param_instruction));
+ fused_param = fused_instructions_computation_->AddParameter(
+ std::move(param_instruction));
AppendOperand(operand);
-
- fused_param = fused_instructions_.back().get();
}
TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
}
@@ -599,18 +605,25 @@ RandomDistribution HloInstruction::random_distribution() const {
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
+ fused_instructions_computation_->instructions();
// All instructions owned by this fusion instruction must be fused, and the
// parent fusion instruction of the fused instructions must be 'this'.
for (auto& instruction : fused_instructions_) {
CHECK(instruction->IsFused());
CHECK_EQ(this, instruction->fusion_instruction());
- CHECK_EQ(parent(), instruction->parent()) << instruction->ToString();
+ CHECK_EQ(fused_instructions_computation_.get(), instruction->parent())
+ << instruction->ToString();
}
// Fused root instruction and fused parameters must all be owned by the fusion
// instruction.
bool root_owned = false;
+ const std::vector<HloInstruction*>& fused_parameters_ = fused_parameters();
+ const HloInstruction* fused_root_ = fused_expression_root();
std::vector<bool> parameter_owned(fused_parameters_.size(), false);
for (auto& instruction : fused_instructions_) {
if (fused_root_ == instruction.get()) {
@@ -838,6 +851,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
}
+HloInstruction::~HloInstruction() {}
+
std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
std::unique_ptr<HloInstruction> clone =
CloneWithNewOperands(shape_, operands_);
@@ -850,6 +865,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(parent() != nullptr);
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
auto new_instruction =
WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
@@ -863,6 +880,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
// Create the list of fused parameters by mapping through the cloned,
// fused instructions.
std::vector<HloInstruction*> new_fused_parameters;
+ const std::vector<HloInstruction*>& fused_parameters_ =
+ fused_instructions_computation_->parameter_instructions();
+ const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
+ fused_instructions_computation_->instructions();
+
for (HloInstruction* old_fused_parameter : fused_parameters_) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
@@ -893,13 +915,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
}
+ new_instruction->fusion_kind_ = fusion_kind_;
+ auto computation_builder = HloComputation::Builder(
+ fused_instructions_computation_->name() + ".clone", true);
// We iterated the fusion instructions in reverse post order which means
// that we must reverse our new list of fusion instructions.
- std::reverse(new_fused_instructions.begin(), new_fused_instructions.end());
- new_instruction->fusion_kind_ = fusion_kind_;
- new_instruction->fused_instructions_ = std::move(new_fused_instructions);
- new_instruction->fused_parameters_ = std::move(new_fused_parameters);
- new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_);
+ for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
+ new_fused_instruction_iter != new_fused_instructions.rend();
+ ++new_fused_instruction_iter) {
+ computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
+ }
+ auto fused_root_ = fused_expression_root();
+ new_instruction->fused_instructions_computation_ =
+ computation_builder.Build(FindOrDie(old_to_new, fused_root_));
new_instruction->set_parent(parent());
new_instruction->CheckFusionInstruction();
return new_instruction;
@@ -1570,6 +1598,11 @@ bool HloInstruction::IsFusable() const {
}
}
+HloComputation* HloInstruction::fused_instructions_computation() const {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+ return fused_instructions_computation_.get();
+}
+
HloInstruction* HloInstruction::fusion_instruction() const {
CHECK(IsFused());
return parent_fusion_instruction_;
@@ -1577,25 +1610,32 @@ HloInstruction* HloInstruction::fusion_instruction() const {
HloInstruction* HloInstruction::fused_expression_root() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_root_;
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ return fused_instructions_computation_->root_instruction();
}
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK_GE(parameter_number, 0);
- CHECK_LT(parameter_number, fused_parameters_.size());
- return fused_parameters_[parameter_number];
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ return fused_instructions_computation_->parameter_instruction(
+ parameter_number);
}
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_parameters_;
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ return fused_instructions_computation_->parameter_instructions();
}
const std::list<std::unique_ptr<HloInstruction>>&
HloInstruction::fused_instructions() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_;
+ CHECK(fused_instructions_computation_ != nullptr &&
+ fused_instructions_computation_->IsFusionComputation());
+ return fused_instructions_computation_->instructions();
}
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
@@ -2076,7 +2116,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
}
return cache[&hlo];
};
- return reuses_parameter_elements(*fused_root_);
+ return reuses_parameter_elements(*fused_expression_root());
}
default:
return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
@@ -2168,4 +2208,10 @@ bool HloInstruction::CouldBeBitcast() const {
}
}
+HloModule* HloInstruction::GetModule() const {
+ if (parent_) {
+ return parent_->parent();
+ }
+ return nullptr;
+}
} // namespace xla
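One subtlety in CloneAndFuseInternal above deserves spelling out: operands of the newly fused clone may already be operands of the fusion instruction, and only genuinely new operands get parameters. A hedged illustration with hypothetical instructions:

// Suppose fusion computes add(p0, p1) over operands (mul, a), where
// mul = multiply(a, b), and mul is fused in:
//   - mul is cloned into the fusion computation; p0 (mul's parameter)
//     is dropped via RemoveParameter, which renumbers p1 down to p0,
//     and mul is erased from the fusion's operand list.
//   - The clone's operand a is already a fusion operand, so it reuses
//     the surviving parameter instead of adding a duplicate.
//   - Operand b is new: AddParameter creates a fresh parameter and b
//     is appended to the fusion instruction's operands.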
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5f160892e9..4f62be5235 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -46,6 +46,7 @@ limitations under the License.
namespace xla {
class HloComputation;
+class HloModule;
// HLO instructions are the IR used by the high-level compiler.
class HloInstruction {
@@ -58,6 +59,7 @@ class HloInstruction {
kConvBackwardInput, // Fused into a backward input convolution.
};
+ ~HloInstruction();
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
const Shape& shape,
@@ -535,6 +537,11 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kFusion
HloInstruction* fused_expression_root() const;
+ // Returns the computation containing the fused instructions of this
+ // fusion instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloComputation* fused_instructions_computation() const;
+
// Returns the vector of fused instructions inside this fusion
// instruction. The order is a reverse postorder of the fused expression (root
// is first in the order).
@@ -719,10 +726,21 @@ class HloInstruction {
const HloComputation* parent() const { return parent_; }
HloComputation* parent() { return parent_; }
+ // Returns the module for this instruction.
+ HloModule* GetModule() const;
+
// Returns whether we could assign input and output layouts to this
// instruction to make it a bitcast.
bool CouldBeBitcast() const;
+ // Sets the parent fusion instruction of this instruction.
+ //
+ // Precondition: fusion_instruction->opcode() == HloOpcode::kFusion
+ void SetParentFusion(HloInstruction* fusion_instruction) {
+ CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
+ parent_fusion_instruction_ = fusion_instruction;
+ }
+
private:
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
@@ -808,22 +826,14 @@ class HloInstruction {
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
- // The set of instruction fused into this fusion instruction. Only set for
- // fusion instructions.
- std::list<std::unique_ptr<HloInstruction>> fused_instructions_;
+ // The computation that stores the instructions fused into this fusion
+ // instruction. Only set for fusion instructions.
+ std::unique_ptr<HloComputation> fused_instructions_computation_;
// If this instruction is fused into a fusion instruction, this field points
// to the fusion instruction.
HloInstruction* parent_fusion_instruction_ = nullptr;
- // The vector of parameter instructions inside this fusion instruction. The
- // index of the vector is the parameter_number of the parameter instruction.
- // This vector is non-empty only for fusion instructions.
- std::vector<HloInstruction*> fused_parameters_;
-
- // The root of the expression fused into this fusion instruction.
- HloInstruction* fused_root_ = nullptr;
-
// The type of the fusion. Used by kFusion only.
FusionKind fusion_kind_;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 1ff5c5dacb..d76f08d3bc 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -106,6 +106,11 @@ class HloModule {
// Returns a randomly generated uint64.
uint64 RandomNew64() const;
+ // Returns a unique name for a computation in this module.
+ string GetUniqueCompuationName(const string& prefix) {
+ return computation_name_uniquer_.GetUniqueName(prefix);
+ }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation);
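A hedged sketch of how the helper above behaves; the exact suffix scheme is NameUniquer's concern and is not pinned down by this diff:

// Fusion computations are named after the fused root, so repeated
// requests for the same prefix are deduplicated by the module:
//   module->GetUniqueCompuationName("multiply.fusion");
//       // -> "multiply.fusion"
//   module->GetUniqueCompuationName("multiply.fusion");
//       // -> a distinct name, e.g. "multiply.fusion.1"
// When no module is reachable (GetModule() returns nullptr),
// CloneAndFuseInternal falls back to the raw, non-uniquified name.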
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 035b570ed3..de6081e57e 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -23,7 +23,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RET_CHECK(instruction->parent() == computation.get());
if (instruction->opcode() == HloOpcode::kFusion) {
for (const auto& fused : instruction->fused_instructions()) {
- TF_RET_CHECK(fused->parent() == computation.get())
+ TF_RET_CHECK(fused->parent() ==
+ instruction->fused_instructions_computation())
<< "Fused HLO was missing a parent: " << fused->ToString()
<< " parent: " << fused->parent()
<< " computation: " << computation.get();