diff options
author | 2018-06-25 16:35:07 -0700 | |
---|---|---|
committer | 2018-06-25 16:38:29 -0700 | |
commit | b33b26cf35230cfe1875509dd2d7ff8a2cf6c581 (patch) | |
tree | b3cc5c0724850c8de34a45dd1a257febc475099f /tensorflow/compiler/xla/service/conditional_simplifier_test.cc | |
parent | 23d602a7da399ded85044a82235ef8cf22ef2be6 (diff) |
Change infeed and outfeed to take and produce tokens.
Tokens are primitive types which can be threaded between side-effecting operations to order them. This CL changes infeed and outfeed to take a token as an operands and produce a token as one of its outputs. The most disruptive aspect of this change is that infeed now produces a two-element tuple containing the data value and a token. This means the shape of infed data no longer is the same as the shape of the infeed instruction, and a get-tuple-element operation must be called on the infeed instructions output to get its data.
Related changes/notes:
- The computation builder interface is unchanged. The infeed builder constructs an infeed instruction followed by a GTE instruction to extract the data value. Client and computation builder interface changes will be in follow up cls.
- Tokens can now be the root of the entry computation. Previously tokens could not be passed into or out of the entry computation. But now that outfeed produces a token, this constraint meant that outfeed could not be a root which is awkward. In the future we'd like to pass in tokens as well, perhaps as the only way of generating the initial token to thread through side-effecting ops.
- Infeed and outfeed still have a form which does not take a token to minimize the size of this CL. In the future this form will be removed. However, most HLO tests using infeed/outfeed are changed to accept a token in this cl.
PiperOrigin-RevId: 202041518
Diffstat (limited to 'tensorflow/compiler/xla/service/conditional_simplifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/conditional_simplifier_test.cc | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 868348547d..cad767c039 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -144,8 +144,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - false_computation->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = false_computation->AddInstruction( + HloInstruction::CreateGenerateToken({})); + false_computation->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } |