diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher_test.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index fef3c132b0..a530581c34 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) { HloInstruction::FusionKind::kLoop))); } +TEST(PatternMatcherTest, GetTupleElement) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + ENTRY while.v11 { + p0 = (f32[], f32[], f32[]) parameter(0) + ROOT gte = f32[] get-tuple-element(p0), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0))); + EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1))); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2))); + EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0))); + EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); +} + } // namespace } // namespace xla |