aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/pattern_matcher.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/pattern_matcher.h')
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h78
1 files changed, 72 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 2515222cf2..ac6ea4c72f 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -86,8 +86,8 @@ namespace xla {
// are provided below.
//
// Example nullary instruction:
-// Recv() == Op().WithOpcode(HloOpcode::kRecv)
-// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv)
+// Param() == Op().WithOpcode(HloOpcode::kParam)
+// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam)
//
// Example unary instruction:
// Abs() == Op().WithOpcode(HloOpcode::kAbs)
@@ -726,6 +726,32 @@ class HloInstructionPatternFusionKindImpl {
::xla::HloInstruction::FusionKind kind_;
};
+// An HloInstructionPattern implementation that matches only if the instruction
+// is a kGetTupleElement with a particular tuple index.
+template <typename Previous>
+class HloInstructionPatternTupleIndexImpl {
+ public:
+ explicit constexpr HloInstructionPatternTupleIndexImpl(
+ const Previous& previous, int64 tuple_index)
+ : previous_(previous), tuple_index_(tuple_index) {}
+
+ bool Match(const ::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) &&
+ inst->opcode() == HloOpcode::kGetTupleElement &&
+ inst->tuple_index() == tuple_index_;
+ }
+
+ bool Match(::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) &&
+ inst->opcode() == HloOpcode::kGetTupleElement &&
+ inst->tuple_index() == tuple_index_;
+ }
+
+ private:
+ Previous previous_;
+ int64 tuple_index_;
+};
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
@@ -841,6 +867,17 @@ class HloInstructionPattern {
HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
}
+ // Modifies the pattern to match only if the instruction is a
+ // get-tuple-element with the given tuple index.
+ constexpr HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternTupleIndexImpl<Impl>>
+ WithTupleIndex(int64 tuple_index) const {
+ return HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternTupleIndexImpl<Impl>>(
+ HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index),
+ matched_inst_);
+ }
+
private:
Impl impl_;
HloInstructionType** matched_inst_;
@@ -880,9 +917,7 @@ Op(::xla::HloInstruction** matched_inst) {
return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
}
XLA_NULLOP_PATTERN(Constant)
-XLA_NULLOP_PATTERN(Infeed)
XLA_NULLOP_PATTERN(Parameter)
-XLA_NULLOP_PATTERN(Recv)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
@@ -919,18 +954,21 @@ XLA_UNOP_PATTERN(Cos)
XLA_UNOP_PATTERN(Exp)
XLA_UNOP_PATTERN(Fft)
XLA_UNOP_PATTERN(Floor)
+XLA_UNOP_PATTERN(GetTupleElement)
XLA_UNOP_PATTERN(Imag)
+XLA_UNOP_PATTERN(Infeed)
XLA_UNOP_PATTERN(IsFinite)
XLA_UNOP_PATTERN(Log)
XLA_UNOP_PATTERN(Not)
XLA_UNOP_PATTERN(Negate)
-XLA_UNOP_PATTERN(Outfeed)
XLA_UNOP_PATTERN(Real)
+XLA_UNOP_PATTERN(Recv)
+XLA_UNOP_PATTERN(RecvDone)
XLA_UNOP_PATTERN(Reduce)
XLA_UNOP_PATTERN(ReducePrecision)
XLA_UNOP_PATTERN(Reshape)
XLA_UNOP_PATTERN(Reverse)
-XLA_UNOP_PATTERN(Send)
+XLA_UNOP_PATTERN(SendDone)
XLA_UNOP_PATTERN(Sign)
XLA_UNOP_PATTERN(Sin)
XLA_UNOP_PATTERN(Sort)
@@ -981,8 +1019,10 @@ XLA_BINOP_PATTERN(Maximum)
XLA_BINOP_PATTERN(Minimum)
XLA_BINOP_PATTERN(Multiply)
XLA_BINOP_PATTERN(Ne)
+XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
+XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
XLA_BINOP_PATTERN(And)
XLA_BINOP_PATTERN(Or)
@@ -1040,6 +1080,32 @@ inline auto NonConstant(HloInstructionType** matched_inst)
return Op(matched_inst).IsNonConstant();
}
+// Add overloads for GetTupleElement which take a int64 specifying which tuple
+// element is selected.
+template <typename Arg>
+inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
+ -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index)) {
+ return Op()
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index);
+}
+
+template <typename HloInstructionType, typename Arg>
+inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
+ int64 tuple_index)
+ -> decltype(Op(matched_inst)
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index)) {
+ return Op(matched_inst)
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index);
+}
+
} // namespace match
} // namespace xla