aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
index 393e758038..266039d2ff 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
@@ -196,5 +196,50 @@ ENTRY entry {
op::GetTupleElement(op::Parameter(0)),
op::GetTupleElement(op::Parameter(0))));
}
+
+TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) {
+ const char* const hlo_string = R"(
+HloModule ModuleWithWhile
+
+body {
+ p_body = (f32[2],f32[2]) parameter(0)
+ p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0
+ p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1
+
+ outfeed = token[] outfeed(p_body.0)
+ ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1)
+}
+
+condition {
+ p_cond = (f32[2],f32[2]) parameter(0)
+ ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+ const_0 = f32[2] constant({1, 2})
+ const_1 = f32[2] constant({2, 1})
+ while_init = (f32[2],f32[2]) tuple(const_0, const_1)
+ ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition,
+ body=body
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_string));
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ WhileLoopConstantSinking{}.Run(module.get()));
+ ASSERT_TRUE(changed);
+
+ auto* while_body = module->GetComputationWithName("body");
+ EXPECT_THAT(while_body->root_instruction(),
+ op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ for (const HloInstruction* inst : while_body->instructions()) {
+ if (inst->opcode() == HloOpcode::kConstant) {
+ EXPECT_GT(inst->user_count(), 0);
+ }
+ }
+}
} // namespace
} // namespace xla