aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/pattern_matcher.h
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-24 16:24:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 16:28:11 -0700
commit18a09eb548db25f6d82760105cf8e1fbbb1343a1 (patch)
tree7ede4f9253a72a9304667331b294b79e5d160ae0 /tensorflow/compiler/xla/service/pattern_matcher.h
parent6c40bc717442d56f0b6a60658b05f0549afd69ee (diff)
Fix Hlo pattern matcher's AnyOf, so that a sub-pattern doesn't capture
when it's not matched. Also add invariant checking for AllOf. PiperOrigin-RevId: 214351269
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher.h')
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h42
1 files changed, 36 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 63b51fc8c9..52c6b51993 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -149,13 +149,19 @@ class AllOfPattern {
explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
bool Match(const Item* item, MatchOption option) const {
- return MatchImpl(item, option,
- absl::make_index_sequence<sizeof...(Patterns)>());
+ bool matched = MatchImpl(item, option,
+ absl::make_index_sequence<sizeof...(Patterns)>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
}
bool Match(Item* item, MatchOption option) const {
- return MatchImpl(item, option,
- absl::make_index_sequence<sizeof...(Patterns)>());
+ bool matched = MatchImpl(item, option,
+ absl::make_index_sequence<sizeof...(Patterns)>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
}
private:
@@ -307,8 +313,32 @@ class AnyOfPattern {
template <typename ItemType, size_t index>
bool MatchImpl(ItemType* item, MatchOption option,
std::integral_constant<size_t, index>) const {
- return std::get<index>(patterns_).Match(item, option) ||
- MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ auto new_option = option;
+ new_option.capture = false;
+ // Try to match the sub-pattern without capturing behavior.
+ if (std::get<index>(patterns_).Match(item, new_option)) {
+ // Capture the branch.
+ if (option.capture) {
+ // TODO(timshen): Currently the behavior can be exponential. Optimize it
+ // with memoization or recording the matched sub-pattern index, if it
+ // takes too long to run.
+ //
+ // Specifically, the "memoization" approach is to create an empty
+ // container with the key (pattern, instruction), and value as whether
+ // matched or not.
+ //
+ // Alternatively, we may run the pattern matching with captures off, but
+ // instead record a "trace" somewhere, indicating how exactly the
+ // pattern matches the input. For example, the trace information for
+ // AnyOf will be a runtime number indicate which sub-pattern is matched.
+ // Then we run another pass to do captures only with the help of the
+ // trace.
+ bool ret = std::get<index>(patterns_).Match(item, option);
+ DCHECK(ret);
+ }
+ return true;
+ }
+ return MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
}
template <typename ItemType>