aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 0536c99b67..2e1571943e 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest,
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* true_op = while_op->while_body()->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
TF_ASSERT_OK(true_op->AddControlDependencyTo(
while_op->while_body()->root_instruction()));
ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
+ token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* recv = while_body->AddInstruction(
- HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
+ HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -208,7 +211,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({}));
+ auto token = while_body->AddInstruction(HloInstruction::CreateToken());
while_body->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());