diff options
author | David Majnemer <majnemer@google.com> | 2017-03-29 13:26:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-29 14:48:18 -0700 |
commit | 5149785eb7175a791acbd9859872e07439b968b6 (patch) | |
tree | 301dec5bfebc225557ce25638c41dfde99fdbaad /tensorflow/compiler/xla/service/instruction_fusion.h | |
parent | cd021175181431c57fdebe0a82e99ffabcc0897f (diff) |
Fusion uses the same cost model for all backends when it comes to
rematerializing for fusion. Let subtypes override the default cost model.
Change: 151626001
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.h')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.h | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index b8fd3dd4f3..a9f3723f2d 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,15 +24,6 @@ limitations under the License. namespace xla { -// Returns true if the computation of the given instruction is significantly -// more expensive than just writing all the values of the instructions' result -// array. Expensive operations should not be duplicated. -bool IsExpensive(const HloInstruction& instruction); - -// Returns true if fusing producer into consumer would cause producer to be -// duplicated. This is the case if producer has uses other than consumer. -bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer); - // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer); // instructions to fuse. class InstructionFusion : public HloPassInterface { public: - explicit InstructionFusion(bool may_duplicate = true) - : may_duplicate_(may_duplicate) {} - ~InstructionFusion() override {} + explicit InstructionFusion( + std::function<bool(const HloInstruction& instruction)> is_expensive, + bool may_duplicate = true) + : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} + ~InstructionFusion() override = default; tensorflow::StringPiece name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). StatusOr<bool> Run(HloModule* module) override; + // Returns true if the computation of the given instruction is significantly + // more expensive than just writing all the values of the instructions' result + // array. Expensive operations will not be duplicated. + static bool IsExpensive(const HloInstruction& instruction); + protected: // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. @@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface { private: HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // Used to determine if an HLO is expensive. Expensive operations will not be + // duplicated. + std::function<bool(const HloInstruction& instruction)> is_expensive_; + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; |