aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_matchers.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-10-31 16:47:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-31 16:51:35 -0700
commit4aa90bfd39832570e84ab049f4c099359f2f608a (patch)
tree7ee074b2d8a0d76f839506415d5366d4546df160 /tensorflow/compiler/xla/service/hlo_matchers.cc
parentf97e7c69b84dac8c3c8c78204d48816036b9bead (diff)
[XLA] Add HLO matchers that check parameter numbers and GTE indices.
This lets you do EXPECT_THAT(foo, op::Parameter(42)); and EXPECT_THAT(bar, op::GetTupleElement(baz, 8)); PiperOrigin-RevId: 174113597
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 0660d5a182..4255d60866 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const {
}
}
+bool HloParameterMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!HloMatcher::MatchAndExplain(instruction, listener)) {
+ return false;
+ }
+ if (instruction->parameter_number() != parameter_number_) {
+ *listener << "has wrong parameter number (got "
+ << instruction->parameter_number() << ", want "
+ << parameter_number_ << ")";
+ return false;
+ }
+ return true;
+}
+
+bool HloGetTupleElementMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!HloMatcher::MatchAndExplain(instruction, listener)) {
+ return false;
+ }
+ if (instruction->tuple_index() != tuple_index_) {
+ *listener << "has wrong tuple index (got " << instruction->tuple_index()
+ << ", want " << tuple_index_ << ")";
+ return false;
+ }
+ return true;
+}
+
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {