aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-25 17:22:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-25 17:24:56 -0700
commit8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch)
tree66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
parent06717b77e05bd602d10fe40f4519dbb105fabd5c (diff)
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they may cause extra copies) we run a late pass of LICM that has been asked to hoist constants when it can. PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc27
1 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 321fdeb1ea..09ddcffb22 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy(
// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
-static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
+bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
+ const HloInstruction& instruction) {
switch (instruction.opcode()) {
default:
return false;
+ case HloOpcode::kConstant:
+ return !hoist_constants_;
+
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
- case HloOpcode::kConstant:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
@@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
}
}
-static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+StatusOr<bool>
+WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
HloInstruction* while_instr) {
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
@@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
}
- if (unhoisted_invariant_instructions.empty()) {
+ if (unhoisted_invariant_instructions.empty() && !hoist_constants_) {
// There are no obviously loop invariant elements in the state being
// threaded through the while loop so give up. In theory this precondition
// is too strong -- we could have code that e.g. permutes the elements in
// the while state but uses a select to pick the same value on every
// iteration.
+ //
+ // If we were asked to hoist constants, we need to scan the while body for
+ // constants even if we didn't find any loop invariant values in the while
+ // state tuple.
return false;
}
@@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
+ VLOG(2) << "HLO module before WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
@@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
changed |= result;
}
+
+ if (changed) {
+ VLOG(2) << "HLO module after WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+ } else {
+ VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
+ }
+
return changed;
}
} // namespace xla