diff options
author | 2018-09-24 16:57:45 -0700 | |
---|---|---|
committer | 2018-09-24 17:01:14 -0700 | |
commit | cc5555d3d3daa64f462cc7f8d31fe915073429f4 (patch) | |
tree | d761287da2f3a5d2dda5234058bed3184817e092 /tensorflow/compiler/xla/service/pattern_matcher.h | |
parent | 7a1096f424b1adcb4152db80a01a163ddb1a0173 (diff) |
Short-circuit AllOf as well. This fixes a crash in ConstantScalar, as it
uses Cast internally.
PiperOrigin-RevId: 214356411
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher.h')
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher.h | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 52c6b51993..380cde0e6a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -149,27 +149,31 @@ class AllOfPattern { explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} bool Match(const Item* item, MatchOption option) const { - bool matched = MatchImpl(item, option, - absl::make_index_sequence<sizeof...(Patterns)>()); + bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); // This invariant is guaranteed by the top-level Match and AnyOf. DCHECK(matched || !option.capture); return matched; } bool Match(Item* item, MatchOption option) const { - bool matched = MatchImpl(item, option, - absl::make_index_sequence<sizeof...(Patterns)>()); + bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); // This invariant is guaranteed by the top-level Match and AnyOf. DCHECK(matched || !option.capture); return matched; } private: - template <typename ItemType, size_t... indices> + template <typename ItemType, size_t index> bool MatchImpl(ItemType* item, MatchOption option, - absl::index_sequence<indices...>) const { - return std::min<bool>( - {std::get<indices>(patterns_).Match(item, option)...}); + std::integral_constant<size_t, index>) const { + return std::get<index>(patterns_).Match(item, option) && + MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); + } + + template <typename ItemType> + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant<size_t, sizeof...(Patterns)>) const { + return true; } std::tuple<Patterns...> patterns_; |