diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-05-25 17:22:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-25 17:24:56 -0700 |
commit | 8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch) | |
tree | 66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h | |
parent | 06717b77e05bd602d10fe40f4519dbb105fabd5c (diff) |
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they
may cause extra copies) we run a late pass of LICM that has been asked to hoist
constants when it can.
PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h')
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b00..8e6cc87875 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can be help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr<bool> Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla |