aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc22
1 files changed, 18 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d481e07f60..5ec9225a68 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -795,10 +795,14 @@ ENTRY ReduceR3ToR2.v3 {
R"(HloModule outfeed_module
ENTRY InfeedToOutfeed {
- infeed = (u32[3]{0}, pred[]) infeed()
- outfeed = () outfeed(infeed)
- ROOT infeed.1 = (u32[3]{0}, pred[]) infeed()
- outfeed.1 = () outfeed(infeed.1)
+ token = token[] generate-token()
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}
)"
@@ -1418,5 +1422,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, NontupleInfeed) {
+ const string original = R"(HloModule nontuple_infeed:
+ENTRY nontuple_infeed {
+ token = token[] generate-token()
+ ROOT infeed = pred[] infeed(token)
+})";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "infeed must have a non-empty tuple shape");
+}
+
} // namespace
} // namespace xla