aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/pattern_matcher_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index fef3c132b0..a530581c34 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) {
HloInstruction::FusionKind::kLoop)));
}
+TEST(PatternMatcherTest, GetTupleElement) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+
+ ENTRY while.v11 {
+ p0 = (f32[], f32[], f32[]) parameter(0)
+ ROOT gte = f32[] get-tuple-element(p0), index=1
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0)));
+ EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1)));
+ EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2)));
+ EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0)));
+ EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
+}
+
} // namespace
} // namespace xla