diff options
author | Tim Shen <timshen@google.com> | 2018-09-21 14:14:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 14:18:25 -0700 |
commit | d0caa5a700dd36b7ac92be2722deaca9a4e23ef4 (patch) | |
tree | dfc057b9eda988c4b79be7e1dcb15408e2164d59 /tensorflow/compiler/xla/service/pattern_matcher_test.cc | |
parent | f4de7ec889311c42b3af4d5f34f7d31f56f73177 (diff) |
Ensure that no capture is done unless Match() return true. Otherwise the
application that relies on such behavior is hard to get right.
To implement this, we need to be careful about AllOf, so that no capture
is done unless all sub-patterns succeed. This leads to the solution that
we have to run all patterns twice, first time with no captures, and
second time to capture.
PiperOrigin-RevId: 214042307
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher_test.cc | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 7bd27268aa..d4e128bd70 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -318,5 +318,43 @@ TEST(PatternMatcherTest, AllOf) { Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern))); } +TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_FALSE( + Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op())))); + EXPECT_EQ(nullptr, constant); + ASSERT_TRUE(Match(root, Constant(&constant))); + EXPECT_NE(nullptr, constant); +} + +TEST(PatternMatcherTest, TestNoCapture) { + using match::Constant; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false})); + EXPECT_EQ(nullptr, constant); +} + } // namespace } // namespace xla |