aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-06-25 16:35:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 16:38:29 -0700
commitb33b26cf35230cfe1875509dd2d7ff8a2cf6c581 (patch)
treeb3cc5c0724850c8de34a45dd1a257febc475099f /tensorflow/compiler/xla/service/hlo_parser_test.cc
parent23d602a7da399ded85044a82235ef8cf22ef2be6 (diff)
Change infeed and outfeed to take and produce tokens.
Tokens are primitive types which can be threaded between side-effecting operations to order them. This CL changes infeed and outfeed to take a token as an operands and produce a token as one of its outputs. The most disruptive aspect of this change is that infeed now produces a two-element tuple containing the data value and a token. This means the shape of infed data no longer is the same as the shape of the infeed instruction, and a get-tuple-element operation must be called on the infeed instructions output to get its data. Related changes/notes: - The computation builder interface is unchanged. The infeed builder constructs an infeed instruction followed by a GTE instruction to extract the data value. Client and computation builder interface changes will be in follow up cls. - Tokens can now be the root of the entry computation. Previously tokens could not be passed into or out of the entry computation. But now that outfeed produces a token, this constraint meant that outfeed could not be a root which is awkward. In the future we'd like to pass in tokens as well, perhaps as the only way of generating the initial token to thread through side-effecting ops. - Infeed and outfeed still have a form which does not take a token to minimize the size of this CL. In the future this form will be removed. However, most HLO tests using infeed/outfeed are changed to accept a token in this cl. PiperOrigin-RevId: 202041518
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc22
1 files changed, 18 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d481e07f60..5ec9225a68 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -795,10 +795,14 @@ ENTRY ReduceR3ToR2.v3 {
R"(HloModule outfeed_module
ENTRY InfeedToOutfeed {
- infeed = (u32[3]{0}, pred[]) infeed()
- outfeed = () outfeed(infeed)
- ROOT infeed.1 = (u32[3]{0}, pred[]) infeed()
- outfeed.1 = () outfeed(infeed.1)
+ token = token[] generate-token()
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}
)"
@@ -1418,5 +1422,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, NontupleInfeed) {
+ const string original = R"(HloModule nontuple_infeed:
+ENTRY nontuple_infeed {
+ token = token[] generate-token()
+ ROOT infeed = pred[] infeed(token)
+})";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "infeed must have a non-empty tuple shape");
+}
+
} // namespace
} // namespace xla