aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/instruction_fusion.h
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-03-29 13:26:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-29 14:48:18 -0700
commit5149785eb7175a791acbd9859872e07439b968b6 (patch)
tree301dec5bfebc225557ce25638c41dfde99fdbaad /tensorflow/compiler/xla/service/instruction_fusion.h
parentcd021175181431c57fdebe0a82e99ffabcc0897f (diff)
Fusion uses the same cost model for all backends when it comes to
rematerializing for fusion. Let subtypes override the default cost model. Change: 151626001
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.h')
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h26
1 files changed, 14 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index b8fd3dd4f3..a9f3723f2d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,15 +24,6 @@ limitations under the License.
namespace xla {
-// Returns true if the computation of the given instruction is significantly
-// more expensive than just writing all the values of the instructions' result
-// array. Expensive operations should not be duplicated.
-bool IsExpensive(const HloInstruction& instruction);
-
-// Returns true if fusing producer into consumer would cause producer to be
-// duplicated. This is the case if producer has uses other than consumer.
-bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
-
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
@@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
// instructions to fuse.
class InstructionFusion : public HloPassInterface {
public:
- explicit InstructionFusion(bool may_duplicate = true)
- : may_duplicate_(may_duplicate) {}
- ~InstructionFusion() override {}
+ explicit InstructionFusion(
+ std::function<bool(const HloInstruction& instruction)> is_expensive,
+ bool may_duplicate = true)
+ : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
+ ~InstructionFusion() override = default;
tensorflow::StringPiece name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
StatusOr<bool> Run(HloModule* module) override;
+ // Returns true if the computation of the given instruction is significantly
+ // more expensive than just writing all the values of the instructions' result
+ // array. Expensive operations will not be duplicated.
+ static bool IsExpensive(const HloInstruction& instruction);
+
protected:
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
@@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface {
private:
HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
+ // Used to determine if an HLO is expensive. Expensive operations will not be
+ // duplicated.
+ std::function<bool(const HloInstruction& instruction)> is_expensive_;
+
// Returns whether we may duplicate an instruction if we want to fuse it.
bool may_duplicate_;