diff options
author | 2018-05-25 17:22:11 -0700 | |
---|---|---|
committer | 2018-05-25 17:24:56 -0700 | |
commit | 8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch) | |
tree | 66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | |
parent | 06717b77e05bd602d10fe40f4519dbb105fabd5c (diff) |
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they
may cause extra copies) we run a late pass of LICM that has been asked to hoist
constants when it can.
PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 321fdeb1ea..09ddcffb22 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( +StatusOr<bool> +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( } } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( } StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector<HloInstruction*> while_instrs; for (auto* comp : module->computations()) { @@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + return changed; } } // namespace xla |