aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-18 16:58:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 17:06:21 -0700
commit93b5dea9663c00d3bb06348143b50b73b6fbacfb (patch)
treea87ca501119eb76a868fc9f83381e28a5369cfe0
parentf7b54ae1b4b215b2944e232ca51604aad1356930 (diff)
Add ConstantScalar, WithPredicate, Disjunction, and OpAnyOrder (where Op
is a commutative binary operator) to the XLA pattern matcher. PiperOrigin-RevId: 213543953
-rw-r--r--tensorflow/compiler/xla/service/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h143
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc84
3 files changed, 222 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fb80c78f68..68bf56c1b1 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -365,8 +365,11 @@ cc_library(
hdrs = ["pattern_matcher.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 4869db79e7..7d4d62ecb9 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -17,8 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#include "absl/strings/string_view.h"
+#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -228,8 +232,46 @@ class LayoutPattern {
LayoutType** matched_layout_;
};
+template <typename Item, typename... Patterns>
+class AnyOfPattern {
+ public:
+ explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item) const {
+ return MatchImpl(item, std::integral_constant<size_t, 0>());
+ }
+
+ bool Match(Item* item) const {
+ return MatchImpl(item, std::integral_constant<size_t, 0>());
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, std::integral_constant<size_t, index>) const {
+ return std::get<index>(patterns_).Match(item) ||
+ MatchImpl(item, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return false;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
} // namespace detail
+// Returns a pattern that represents the logical disjunction of the input
+// patterns. The returned pattern matches from left to right, and stops on the
+// first match.
+template <typename Item, typename... Patterns>
+detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
+ const Patterns&... patterns) {
+ return detail::AnyOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
// Creates a layout pattern that will capture the matched layout in the
// argument.
inline constexpr detail::LayoutPattern<const ::xla::Layout,
@@ -752,6 +794,27 @@ class HloInstructionPatternTupleIndexImpl {
int64 tuple_index_;
};
+template <typename Previous, typename ItemType, typename Predicate>
+class HloPredicatePatternImpl {
+ public:
+ explicit HloPredicatePatternImpl(const Previous& previous, Predicate pred)
+ : previous_(previous), pred_(std::move(pred)) {}
+
+ bool Match(const ItemType* item) const {
+ return previous_.Match(item) && pred_(item);
+ }
+
+ bool Match(ItemType* item) const {
+ return previous_.Match(item) && pred_(item);
+ }
+
+ private:
+ Previous previous_;
+ Predicate pred_;
+};
+
+struct PatternFriend;
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
@@ -879,6 +942,21 @@ class HloInstructionPattern {
}
private:
+ template <typename Predicate>
+ constexpr HloInstructionPattern<
+ HloInstructionType,
+ HloPredicatePatternImpl<
+ Impl, typename std::remove_const<HloInstructionType>::type,
+ Predicate>>
+ WithPredicate(Predicate pred) const {
+ using NewImplType = HloPredicatePatternImpl<
+ Impl, typename std::remove_const<HloInstructionType>::type, Predicate>;
+ return HloInstructionPattern<HloInstructionType, NewImplType>(
+ NewImplType(impl_, std::move(pred)), matched_inst_);
+ }
+
+ friend struct PatternFriend;
+
Impl impl_;
HloInstructionType** matched_inst_;
};
@@ -1005,31 +1083,50 @@ XLA_UNOP_PATTERN(Transpose)
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)); \
}
-XLA_BINOP_PATTERN(Add)
+
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
+ XLA_BINOP_PATTERN(NAME) \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \
+ return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs))) { \
+ return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs)); \
+ }
+XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(Eq)
+XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
XLA_BINOP_PATTERN(Gt)
XLA_BINOP_PATTERN(Le)
XLA_BINOP_PATTERN(Lt)
-XLA_BINOP_PATTERN(Maximum)
-XLA_BINOP_PATTERN(Minimum)
-XLA_BINOP_PATTERN(Multiply)
-XLA_BINOP_PATTERN(Ne)
+XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
+XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
-XLA_BINOP_PATTERN(And)
-XLA_BINOP_PATTERN(Or)
+XLA_COMMUTATIVE_BINOP_PATTERN(And)
+XLA_COMMUTATIVE_BINOP_PATTERN(Or)
XLA_BINOP_PATTERN(ShiftLeft)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_COMMUTATIVE_BINOP_PATTERN
#undef XLA_BINOP_PATTERN
// Helpers for ternary instructions.
@@ -1070,6 +1167,30 @@ XLA_TERNOP_PATTERN(Clamp);
XLA_TERNOP_PATTERN(Select);
#undef XLA_TERNOP_PATTERN
+namespace detail {
+struct PatternFriend {
+ template <typename T>
+ static auto ConstantScalar(T constant) -> decltype(
+ Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(
+ std::declval<std::function<bool(const HloInstruction*)>>())) {
+ std::function<bool(const HloInstruction*)> pred =
+ [constant](const HloInstruction* instr) {
+ const auto& literal = Cast<HloConstantInstruction>(instr)->literal();
+ auto status_or_const = LiteralUtil::CreateR0(constant).Convert(
+ literal.shape().element_type());
+ return status_or_const.ok() &&
+ literal == status_or_const.ConsumeValueOrDie();
+ };
+
+ return Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(std::move(pred));
+ }
+};
+} // namespace detail
+
// Helpers for matching non-constant instructions.
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
return Op().IsNonConstant();
@@ -1107,6 +1228,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
.WithTupleIndex(tuple_index);
}
+template <typename T>
+inline auto ConstantScalar(T constant)
+ -> decltype(detail::PatternFriend::ConstantScalar(constant)) {
+ return detail::PatternFriend::ConstantScalar(constant);
+}
+
} // namespace match
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index a530581c34..b3a2c954b3 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -211,5 +211,89 @@ TEST(PatternMatcherTest, GetTupleElement) {
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
}
+TEST(PatternMatcherTest, AnyOf) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(1))));
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1),
+ match::ConstantScalar(0))));
+ EXPECT_FALSE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(2))));
+}
+
+TEST(PatternMatcherTest, ConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
+}
+
+TEST(PatternMatcherTest, MultiplyAnyOrder) {
+ using match::ConstantScalar;
+ using match::MultiplyAnyOrder;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ const HloInstruction* instr;
+
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
+}
+
+TEST(PatternMatcherTest, AnyOfShortCircuit) {
+ using match::AnyOf;
+ using match::Multiply;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any))));
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(nullptr, any);
+ }
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op()))));
+ EXPECT_NE(nullptr, any);
+ EXPECT_EQ(nullptr, mul);
+ }
+}
+
} // namespace
} // namespace xla