aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_matchers.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-01-26 15:54:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 16:01:55 -0800
commit1d1a50e3a5f0e297e6d4d480cf28ca5be51d7c73 (patch)
tree6cadd13c3a246b7f24839b2463c00bfbd9bbf31e /tensorflow/compiler/xla/service/hlo_matchers.cc
parent3e9bf0874ed19b1f96f835c444a4b80167de4663 (diff)
[XLA] (Re-land) Add HLO matcher for CustomCall that accepts a call target.
Now with less build breakage! PiperOrigin-RevId: 183458987
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 4255d60866..bc74c4bc10 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -102,6 +102,36 @@ bool HloGetTupleElementMatcher::MatchAndExplain(
return true;
}
+void HloCustomCallMatcher::DescribeTo(std::ostream* os) const {
+ HloMatcher::DescribeTo(os);
+ *os << " with call target that ";
+ call_target_matcher_.DescribeTo(os);
+}
+
+bool HloCustomCallMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!HloMatcher::MatchAndExplain(instruction, listener)) {
+ return false;
+ }
+ ::testing::StringMatchResultListener sub_listener;
+ bool result = ExplainMatchResult(
+ call_target_matcher_, instruction->custom_call_target(), &sub_listener);
+ if (sub_listener.str().empty()) {
+ sub_listener << " that ";
+
+ std::stringstream desc_stream;
+ if (result) {
+ call_target_matcher_.DescribeTo(&desc_stream);
+ } else {
+ call_target_matcher_.DescribeNegationTo(&desc_stream);
+ }
+ sub_listener << desc_stream.str();
+ }
+ *listener << "custom-call with call target" << sub_listener.str();
+ return result;
+}
+
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {