diff options
author | Tim Shen <timshen@google.com> | 2018-09-18 16:58:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 17:06:21 -0700 |
commit | 93b5dea9663c00d3bb06348143b50b73b6fbacfb (patch) | |
tree | a87ca501119eb76a868fc9f83381e28a5369cfe0 | |
parent | f7b54ae1b4b215b2944e232ca51604aad1356930 (diff) |
Add ConstantScalar, WithPredicate, Disjunction, and OpAnyOrder (where Op
is a commutative binary operator) to the XLA pattern matcher.
PiperOrigin-RevId: 213543953
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher.h | 143 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/pattern_matcher_test.cc | 84 |
3 files changed, 222 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index fb80c78f68..68bf56c1b1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -365,8 +365,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/utility", ], ) diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 4869db79e7..7d4d62ecb9 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -17,8 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #include "absl/strings/string_view.h" +#include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -228,8 +232,46 @@ class LayoutPattern { LayoutType** matched_layout_; }; +template <typename Item, typename... Patterns> +class AnyOfPattern { + public: + explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {} + + bool Match(const Item* item) const { + return MatchImpl(item, std::integral_constant<size_t, 0>()); + } + + bool Match(Item* item) const { + return MatchImpl(item, std::integral_constant<size_t, 0>()); + } + + private: + template <typename ItemType, size_t index> + bool MatchImpl(ItemType* item, std::integral_constant<size_t, index>) const { + return std::get<index>(patterns_).Match(item) || + MatchImpl(item, std::integral_constant<size_t, index + 1>()); + } + + template <typename ItemType> + bool MatchImpl(ItemType* item, + std::integral_constant<size_t, sizeof...(Patterns)>) const { + return false; + } + + std::tuple<Patterns...> patterns_; +}; } // namespace detail +// Returns a pattern that represents the logical disjunction of the input +// patterns. The returned pattern matches from left to right, and stops on the +// first match. +template <typename Item, typename... Patterns> +detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf( + const Patterns&... patterns) { + return detail::AnyOfPattern<typename std::remove_const<Item>::type, + Patterns...>(patterns...); +} + // Creates a layout pattern that will capture the matched layout in the // argument. inline constexpr detail::LayoutPattern<const ::xla::Layout, @@ -752,6 +794,27 @@ class HloInstructionPatternTupleIndexImpl { int64 tuple_index_; }; +template <typename Previous, typename ItemType, typename Predicate> +class HloPredicatePatternImpl { + public: + explicit HloPredicatePatternImpl(const Previous& previous, Predicate pred) + : previous_(previous), pred_(std::move(pred)) {} + + bool Match(const ItemType* item) const { + return previous_.Match(item) && pred_(item); + } + + bool Match(ItemType* item) const { + return previous_.Match(item) && pred_(item); + } + + private: + Previous previous_; + Predicate pred_; +}; + +struct PatternFriend; + // A pattern that matches HloInstructions. template <typename HloInstructionType, typename Impl> class HloInstructionPattern { @@ -879,6 +942,21 @@ class HloInstructionPattern { } private: + template <typename Predicate> + constexpr HloInstructionPattern< + HloInstructionType, + HloPredicatePatternImpl< + Impl, typename std::remove_const<HloInstructionType>::type, + Predicate>> + WithPredicate(Predicate pred) const { + using NewImplType = HloPredicatePatternImpl< + Impl, typename std::remove_const<HloInstructionType>::type, Predicate>; + return HloInstructionPattern<HloInstructionType, NewImplType>( + NewImplType(impl_, std::move(pred)), matched_inst_); + } + + friend struct PatternFriend; + Impl impl_; HloInstructionType** matched_inst_; }; @@ -1005,31 +1083,50 @@ XLA_UNOP_PATTERN(Transpose) .WithOperand(0, std::forward<Lhs>(lhs)) \ .WithOperand(1, std::forward<Rhs>(rhs)); \ } -XLA_BINOP_PATTERN(Add) + +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template <typename Lhs, typename Rhs> \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \ + return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \ + } \ + \ + template <typename HloInstructionType, typename Lhs, typename Rhs> \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs))) { \ + return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs)); \ + } +XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(Eq) +XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) XLA_BINOP_PATTERN(Gt) XLA_BINOP_PATTERN(Le) XLA_BINOP_PATTERN(Lt) -XLA_BINOP_PATTERN(Maximum) -XLA_BINOP_PATTERN(Minimum) -XLA_BINOP_PATTERN(Multiply) -XLA_BINOP_PATTERN(Ne) +XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) +XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) +XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) +XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) -XLA_BINOP_PATTERN(And) -XLA_BINOP_PATTERN(Or) +XLA_COMMUTATIVE_BINOP_PATTERN(And) +XLA_COMMUTATIVE_BINOP_PATTERN(Or) XLA_BINOP_PATTERN(ShiftLeft) XLA_BINOP_PATTERN(ShiftRightArithmetic) XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_COMMUTATIVE_BINOP_PATTERN #undef XLA_BINOP_PATTERN // Helpers for ternary instructions. @@ -1070,6 +1167,30 @@ XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN +namespace detail { +struct PatternFriend { + template <typename T> + static auto ConstantScalar(T constant) -> decltype( + Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate( + std::declval<std::function<bool(const HloInstruction*)>>())) { + std::function<bool(const HloInstruction*)> pred = + [constant](const HloInstruction* instr) { + const auto& literal = Cast<HloConstantInstruction>(instr)->literal(); + auto status_or_const = LiteralUtil::CreateR0(constant).Convert( + literal.shape().element_type()); + return status_or_const.ok() && + literal == status_or_const.ConsumeValueOrDie(); + }; + + return Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate(std::move(pred)); + } +}; +} // namespace detail + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1107,6 +1228,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } +template <typename T> +inline auto ConstantScalar(T constant) + -> decltype(detail::PatternFriend::ConstantScalar(constant)) { + return detail::PatternFriend::ConstantScalar(constant); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index a530581c34..b3a2c954b3 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -211,5 +211,89 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } +TEST(PatternMatcherTest, AnyOf) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0), + match::ConstantScalar(1)))); + EXPECT_TRUE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1), + match::ConstantScalar(0)))); + EXPECT_FALSE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0), + match::ConstantScalar(2)))); +} + +TEST(PatternMatcherTest, ConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(root, match::ConstantScalar(41))); + EXPECT_FALSE(Match(root, match::ConstantScalar(0))); +} + +TEST(PatternMatcherTest, MultiplyAnyOrder) { + using match::ConstantScalar; + using match::MultiplyAnyOrder; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + const HloInstruction* instr; + + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); +} + +TEST(PatternMatcherTest, AnyOfShortCircuit) { + using match::AnyOf; + using match::Multiply; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any)))); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(nullptr, any); + } + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op())))); + EXPECT_NE(nullptr, any); + EXPECT_EQ(nullptr, mul); + } +} + } // namespace } // namespace xla |