aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-06-25 16:35:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 16:38:29 -0700
commitb33b26cf35230cfe1875509dd2d7ff8a2cf6c581 (patch)
treeb3cc5c0724850c8de34a45dd1a257febc475099f /tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
parent23d602a7da399ded85044a82235ef8cf22ef2be6 (diff)
Change infeed and outfeed to take and produce tokens.
Tokens are primitive types which can be threaded between side-effecting operations to order them. This CL changes infeed and outfeed to take a token as an operands and produce a token as one of its outputs. The most disruptive aspect of this change is that infeed now produces a two-element tuple containing the data value and a token. This means the shape of infed data no longer is the same as the shape of the infeed instruction, and a get-tuple-element operation must be called on the infeed instructions output to get its data. Related changes/notes: - The computation builder interface is unchanged. The infeed builder constructs an infeed instruction followed by a GTE instruction to extract the data value. Client and computation builder interface changes will be in follow up cls. - Tokens can now be the root of the entry computation. Previously tokens could not be passed into or out of the entry computation. But now that outfeed produces a token, this constraint meant that outfeed could not be a root which is awkward. In the future we'd like to pass in tokens as well, perhaps as the only way of generating the initial token to thread through side-effecting ops. - Infeed and outfeed still have a form which does not take a token to minimize the size of this CL. In the future this form will be removed. However, most HLO tests using infeed/outfeed are changed to accept a token in this cl. PiperOrigin-RevId: 202041518
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc39
1 files changed, 29 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 8831c513ee..56a3778efd 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest,
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
-
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
- EXPECT_FALSE(simplified_loop);
+ ASSERT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Outfeed()));
@@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
// bitcast either.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
HloInstruction* bitcast_inst = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());