aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-25 17:22:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-25 17:24:56 -0700
commit8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch)
tree66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
parent06717b77e05bd602d10fe40f4519dbb105fabd5c (diff)
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they may cause extra copies) we run a late pass of LICM that has been asked to hoist constants when it can. PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8c4b765b00..8e6cc87875 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -27,12 +27,28 @@ namespace xla {
class WhileLoopInvariantCodeMotion : public HloPassInterface {
public:
+ // If `hoist_constants` is true then constants are always hoisted out of while
+ // loop bodies. Otherwise they are only hoisted out if they enable other
+ // non-trivial computations to be hoisted out.
+ //
+ // Setting `hoist_constants` to false can be help if LICM is run in the mid
+ // level HLO pipeline because hoisting constants out of while loop bodies can
+ // break optimizations like constant folding.
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false)
+ : hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
tensorflow::StringPiece name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool NotWorthHoistingIndividually(const HloInstruction& instruction);
+ StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+ HloInstruction* while_instr);
+
+ bool hoist_constants_;
};
} // namespace xla