aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-30 13:19:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 13:23:54 -0700
commite565d1f1fced69789feb10f1ea1241157ec95f93 (patch)
tree0351aafe956f71c30f6aa27da7917bc10a0afdf0 /tensorflow/compiler/xla/service/hlo_computation.h
parent79603e4dd92b182ca0689da97d4d992d1d2b1316 (diff)
[XLA] Refactor parent-fusion-instruction pointer into HloComputation, not HloInstruction.
Presently, each instruction inside a fusion computation contains a pointer to the fusion instruction that contains the computation, which is redundant since this is common across the entire computation. This leads to lots of places where this pointer must be set when adding an instruction to the fusion computation (and bugs such as b/65177535 when one is missed), as well as code to check that it's set correctly. In addition, this is simply unnecessary data bloat. Moreover, the computation itself does not contain a pointer to the fusion instruction that references it, which leads to odd circumlocutions in the HloComputation code that retrieve the fusion instruction from the computation's root instruction. Thus, this CL moves this pointer into the HloComputation class (replacing the is_fusion_computation_ bool value), and refactor the uses as necessary. PiperOrigin-RevId: 167039280
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h21
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index f383a17fb8..576c44a9f3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -56,10 +56,11 @@ class HloComputation {
// Builder class for HloComputation.
class Builder {
public:
- explicit Builder(const string& name, bool is_fusion_computation = false)
+ explicit Builder(const string& name,
+ HloInstruction* fusion_instruction = nullptr)
: name_(name),
last_added_instruction_(nullptr),
- is_fusion_computation_(is_fusion_computation) {}
+ fusion_instruction_(fusion_instruction) {}
// Build and return an HloComputation. The parameter root_instruction
// specifies the already-added instruction to use as the root. If
@@ -78,7 +79,7 @@ class HloComputation {
private:
const string name_;
HloInstruction* last_added_instruction_;
- bool is_fusion_computation_;
+ HloInstruction* fusion_instruction_;
std::vector<std::unique_ptr<HloInstruction>> instructions_;
};
@@ -274,13 +275,18 @@ class HloComputation {
bool HasSideEffect() const;
// Returns if this computation is a fusion computation.
- bool IsFusionComputation() const { return is_fusion_computation_; }
+ bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }
+
+ // Returns the owning fusion instruction, or nullptr if this is not a fusion
+ // computation.
+ HloInstruction* FusionInstruction() const { return fusion_instruction_; }
private:
explicit HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
- HloInstruction* root_instruction, bool is_fusion_computation = false);
+ HloInstruction* root_instruction,
+ HloInstruction* fusion_instruction = nullptr);
// Internal helper for adding instructions.
HloInstruction* AddInstructionInternal(
@@ -309,8 +315,9 @@ class HloComputation {
string name_;
HloInstruction* root_instruction_;
- // A tag shows if this is a fusion computation.
- bool is_fusion_computation_;
+ // If this computation is a fusion computation, this field points to the
+ // corresponding fusion instruction. Otherwise, this is null.
+ HloInstruction* fusion_instruction_;
// Module containing this computation.
HloModule* parent_ = nullptr;