diff options
author | Mark Heffernan <meheff@google.com> | 2018-06-25 16:35:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-25 16:38:29 -0700 |
commit | b33b26cf35230cfe1875509dd2d7ff8a2cf6c581 (patch) | |
tree | b3cc5c0724850c8de34a45dd1a257febc475099f /tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc | |
parent | 23d602a7da399ded85044a82235ef8cf22ef2be6 (diff) |
Change infeed and outfeed to take and produce tokens.
Tokens are primitive types which can be threaded between side-effecting operations to order them. This CL changes infeed and outfeed to take a token as an operands and produce a token as one of its outputs. The most disruptive aspect of this change is that infeed now produces a two-element tuple containing the data value and a token. This means the shape of infed data no longer is the same as the shape of the infeed instruction, and a get-tuple-element operation must be called on the infeed instructions output to get its data.
Related changes/notes:
- The computation builder interface is unchanged. The infeed builder constructs an infeed instruction followed by a GTE instruction to extract the data value. Client and computation builder interface changes will be in follow up cls.
- Tokens can now be the root of the entry computation. Previously tokens could not be passed into or out of the entry computation. But now that outfeed produces a token, this constraint meant that outfeed could not be a root which is awkward. In the future we'd like to pass in tokens as well, perhaps as the only way of generating the initial token to thread through side-effecting ops.
- Infeed and outfeed still have a form which does not take a token to minimize the size of this CL. In the future this form will be removed. However, most HLO tests using infeed/outfeed are changed to accept a token in this cl.
PiperOrigin-RevId: 202041518
Diffstat (limited to 'tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc | 39 |
1 files changed, 29 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8831c513ee..56a3778efd 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); - + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, WhileLoopInvariantCodeMotion{}.Run(&module())); - EXPECT_FALSE(simplified_loop); + ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Outfeed())); @@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); HloInstruction* bitcast_inst = builder.AddInstruction( HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); |