diff options
author | Justin Lebar <jlebar@google.com> | 2017-10-31 16:47:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-31 16:51:35 -0700 |
commit | 4aa90bfd39832570e84ab049f4c099359f2f608a (patch) | |
tree | 7ee074b2d8a0d76f839506415d5366d4546df160 /tensorflow/compiler/xla/service/hlo_matchers.cc | |
parent | f97e7c69b84dac8c3c8c78204d48816036b9bead (diff) |
[XLA] Add HLO matchers that check parameter numbers and GTE indices.
This lets you do
EXPECT_THAT(foo, op::Parameter(42));
and
EXPECT_THAT(bar, op::GetTupleElement(baz, 8));
PiperOrigin-RevId: 174113597
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 0660d5a182..4255d60866 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const { } } +bool HloParameterMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->parameter_number() != parameter_number_) { + *listener << "has wrong parameter number (got " + << instruction->parameter_number() << ", want " + << parameter_number_ << ")"; + return false; + } + return true; +} + +bool HloGetTupleElementMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->tuple_index() != tuple_index_) { + *listener << "has wrong tuple index (got " << instruction->tuple_index() + << ", want " << tuple_index_ << ")"; + return false; + } + return true; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { |