diff options
author | 2018-09-24 16:24:21 -0700 | |
---|---|---|
committer | 2018-09-24 16:28:11 -0700 | |
commit | 18a09eb548db25f6d82760105cf8e1fbbb1343a1 (patch) | |
tree | 7ede4f9253a72a9304667331b294b79e5d160ae0 /tensorflow/compiler/xla/service/pattern_matcher.h | |
parent | 6c40bc717442d56f0b6a60658b05f0549afd69ee (diff) |
Fix Hlo pattern matcher's AnyOf, so that a sub-pattern doesn't capture
when it's not matched.
Also add invariant checking for AllOf.
PiperOrigin-RevId: 214351269
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher.h')
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher.h | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 63b51fc8c9..52c6b51993 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -149,13 +149,19 @@ class AllOfPattern { explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} bool Match(const Item* item, MatchOption option) const { - return MatchImpl(item, option, - absl::make_index_sequence<sizeof...(Patterns)>()); + bool matched = MatchImpl(item, option, + absl::make_index_sequence<sizeof...(Patterns)>()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; } bool Match(Item* item, MatchOption option) const { - return MatchImpl(item, option, - absl::make_index_sequence<sizeof...(Patterns)>()); + bool matched = MatchImpl(item, option, + absl::make_index_sequence<sizeof...(Patterns)>()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; } private: @@ -307,8 +313,32 @@ class AnyOfPattern { template <typename ItemType, size_t index> bool MatchImpl(ItemType* item, MatchOption option, std::integral_constant<size_t, index>) const { - return std::get<index>(patterns_).Match(item, option) || - MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); + auto new_option = option; + new_option.capture = false; + // Try to match the sub-pattern without capturing behavior. + if (std::get<index>(patterns_).Match(item, new_option)) { + // Capture the branch. + if (option.capture) { + // TODO(timshen): Currently the behavior can be exponential. Optimize it + // with memoization or recording the matched sub-pattern index, if it + // takes too long to run. + // + // Specifically, the "memoization" approach is to create an empty + // container with the key (pattern, instruction), and value as whether + // matched or not. + // + // Alternatively, we may run the pattern matching with captures off, but + // instead record a "trace" somewhere, indicating how exactly the + // pattern matches the input. For example, the trace information for + // AnyOf will be a runtime number indicate which sub-pattern is matched. + // Then we run another pass to do captures only with the help of the + // trace. + bool ret = std::get<index>(patterns_).Match(item, option); + DCHECK(ret); + } + return true; + } + return MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); } template <typename ItemType> |