aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_constant_sinking.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.cc6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 10fc4958fa..62af45128a 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -61,6 +61,12 @@ StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody(
WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) {
int64 index = invariant_gte->tuple_index();
const HloInstruction& invariant_value = *init_value.operand(index);
+
+ // Should have at least one user that's not while_body_root.
+ if (invariant_gte->user_count() <= 1) {
+ continue;
+ }
+
if (invariant_value.opcode() == HloOpcode::kConstant) {
auto* constant_instr =
while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk"));