aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-01-25 14:56:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 15:01:19 -0800
commit022890f6ac03bb87cc7b4f1a5b722cd6b058e616 (patch)
treec659ea233a0cb0892b3a5ccb0a9185a07b0955f7
parentad5c04c9e1151c4de71288520d45f3b3142299fb (diff)
[XLA] Add HLO matcher for CustomCall that accepts a call target.
PiperOrigin-RevId: 183296506
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h53
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc33
3 files changed, 107 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 4255d60866..fe1bf61e97 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -102,6 +102,30 @@ bool HloGetTupleElementMatcher::MatchAndExplain(
return true;
}
+void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
+ HloMatcher::DescribeTo(os);
+ *os << " with call target that "
+ << ::testing::DescribeMatcher<string>(call_target_matcher_);
+}
+
+bool HloCustomCallMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!HloMatcher::MatchAndExplain(instruction, listener)) {
+ return false;
+ }
+ ::testing::StringMatchResultListener sub_listener;
+ bool result = ExplainMatchResult(
+ call_target_matcher_, instruction->custom_call_target(), &sub_listener);
+ if (sub_listener.str().empty()) {
+ sub_listener << " that "
+ << ::testing::DescribeMatcher<string>(call_target_matcher_,
+ /*negation=*/!result);
+ }
+ *listener << "custom-call with call target" << sub_listener.str();
+ return result;
+}
+
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 9206cdac05..103f04a2cb 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -56,8 +56,8 @@ class HloParameterMatcher : public HloMatcher {
// index to match.
class HloGetTupleElementMatcher : public HloMatcher {
public:
- explicit HloGetTupleElementMatcher(
- ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index)
+ HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
+ int64 tuple_index)
: HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
tuple_index_(tuple_index) {}
@@ -68,6 +68,24 @@ class HloGetTupleElementMatcher : public HloMatcher {
int64 tuple_index_;
};
+// Custom matcher for custom-call instructions, which accepts a matcher for its
+// call target.
+class HloCustomCallMatcher : public HloMatcher {
+ public:
+ HloCustomCallMatcher(
+ ::testing::Matcher<string> call_target_matcher,
+ std::vector<::testing::Matcher<const HloInstruction*>> operands)
+ : HloMatcher(HloOpcode::kCustomCall, operands),
+ call_target_matcher_(call_target_matcher) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ ::testing::Matcher<string> call_target_matcher_;
+};
+
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@@ -94,7 +112,6 @@ HLO_MATCHER(Convert);
HLO_MATCHER(Convolution);
HLO_MATCHER(Copy);
HLO_MATCHER(CrossReplicaSum);
-HLO_MATCHER(CustomCall);
HLO_MATCHER(Divide);
HLO_MATCHER(Dot);
HLO_MATCHER(DynamicSlice);
@@ -184,6 +201,36 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
}
+// - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
+// target T and the given operands.
+//
+// - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
+// given operands.
+//
+// - CustomCall() matches any CustomCall HLO at all.
+template <typename... M>
+inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
+ ::testing::Matcher<string> call_target_matcher, M... operands) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
+ call_target_matcher, {operands...}));
+}
+// This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
+// ::testing::Matcher<string>. In that case, we want to prefer the overload
+// above.
+template <typename FirstM, typename... M,
+ typename Dummy = typename std::enable_if<
+ !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
+ void>::type*>
+inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
+ FirstM operands_first, M... operands_rest) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
+ HloOpcode::kCustomCall, {operands_first, operands_rest...}));
+}
+inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
+}
+
#undef HLO_MATCHER
} // namespace opcode_matchers
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 1465d1cacd..1c21703a45 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -23,6 +23,12 @@ using ::testing::Eq;
namespace xla {
namespace {
+string DescribeHloMatcher(const ::testing::Matcher<const HloInstruction*>& m) {
+ std::stringstream ss;
+ m.DescribeTo(&ss);
+ return ss.str();
+}
+
template <typename M, typename T>
string Explain(const T& t, const M& m) {
::testing::StringMatchResultListener listener;
@@ -67,5 +73,32 @@ TEST(HloMatchersTest, Test) {
"add"));
}
+TEST(HloMatchersTest, CustomCallMatcher) {
+ auto c1 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
+ auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}));
+ auto call = HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target");
+
+ EXPECT_THAT(call.get(), op::CustomCall());
+ EXPECT_THAT(call.get(), op::CustomCall(c1.get(), c2.get()));
+ EXPECT_THAT(call.get(), op::CustomCall("foo_target"));
+ EXPECT_THAT(call.get(), op::CustomCall("foo_target", c1.get(), c2.get()));
+ EXPECT_THAT(call.get(), op::CustomCall(::testing::StartsWith("foo")));
+ EXPECT_THAT(call.get(),
+ op::CustomCall(::testing::Not(::testing::StartsWith("bar"))));
+
+ // Wrong number of operands.
+ EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(c1.get())));
+
+ // Call target does not match.
+ EXPECT_THAT(call.get(),
+ ::testing::Not(op::CustomCall(::testing::StartsWith("bar"))));
+
+ EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")),
+ R"(custom-call with call target that isn't equal to "bar")");
+ EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")),
+ R"(custom-call with call target that is equal to "foo_target")");
+}
+
} // namespace
} // namespace xla