diff options
author | Justin Lebar <jlebar@google.com> | 2018-01-26 15:54:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-26 16:01:55 -0800 |
commit | 1d1a50e3a5f0e297e6d4d480cf28ca5be51d7c73 (patch) | |
tree | 6cadd13c3a246b7f24839b2463c00bfbd9bbf31e /tensorflow/compiler/xla/service/hlo_matchers.cc | |
parent | 3e9bf0874ed19b1f96f835c444a4b80167de4663 (diff) |
[XLA] (Re-land) Add HLO matcher for CustomCall that accepts a call target.
Now with less build breakage!
PiperOrigin-RevId: 183458987
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers.cc | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 4255d60866..bc74c4bc10 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -102,6 +102,36 @@ bool HloGetTupleElementMatcher::MatchAndExplain( return true; } +void HloCustomCallMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with call target that "; + call_target_matcher_.DescribeTo(os); +} + +bool HloCustomCallMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + ::testing::StringMatchResultListener sub_listener; + bool result = ExplainMatchResult( + call_target_matcher_, instruction->custom_call_target(), &sub_listener); + if (sub_listener.str().empty()) { + sub_listener << " that "; + + std::stringstream desc_stream; + if (result) { + call_target_matcher_.DescribeTo(&desc_stream); + } else { + call_target_matcher_.DescribeNegationTo(&desc_stream); + } + sub_listener << desc_stream.str(); + } + *listener << "custom-call with call target" << sub_listener.str(); + return result; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { |