aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/pattern_matcher_test.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-21 14:14:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 14:18:25 -0700
commitd0caa5a700dd36b7ac92be2722deaca9a4e23ef4 (patch)
treedfc057b9eda988c4b79be7e1dcb15408e2164d59 /tensorflow/compiler/xla/service/pattern_matcher_test.cc
parentf4de7ec889311c42b3af4d5f34f7d31f56f73177 (diff)
Ensure that no capture is done unless Match() return true. Otherwise the
application that relies on such behavior is hard to get right. To implement this, we need to be careful about AllOf, so that no capture is done unless all sub-patterns succeed. This leads to the solution that we have to run all patterns twice, first time with no captures, and second time to capture. PiperOrigin-RevId: 214042307
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index 7bd27268aa..d4e128bd70 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -318,5 +318,43 @@ TEST(PatternMatcherTest, AllOf) {
Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
}
+TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_FALSE(
+ Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op()))));
+ EXPECT_EQ(nullptr, constant);
+ ASSERT_TRUE(Match(root, Constant(&constant)));
+ EXPECT_NE(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestNoCapture) {
+ using match::Constant;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false}));
+ EXPECT_EQ(nullptr, constant);
+}
+
} // namespace
} // namespace xla