aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_matchers.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-24 07:06:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 07:09:20 -0700
commit5eb233d0686636a7bacc5b8813c079b6b9aa483c (patch)
tree5e2cbbe141fe5ba07e43d8765a9f7f2cee1af226 /tensorflow/compiler/xla/service/hlo_matchers.h
parentb9e12bc69df65eca279a90045d045e661fdb8108 (diff)
Introduce a new HLO shape and sharding matcher.
These new matchers can be used in tests in combination to the existing HLO opcode matchers to better verify a generated HLO graph. PiperOrigin-RevId: 194082100
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h69
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 103f04a2cb..f2ab9b5d9b 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
namespace testing {
@@ -86,6 +87,50 @@ class HloCustomCallMatcher : public HloMatcher {
::testing::Matcher<string> call_target_matcher_;
};
+class HloShapeMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ Shape shape_;
+};
+
+class HloShapeAndLayoutMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ Shape shape_;
+};
+
+// Verify the sharding of an instruction against the provided HloSharding. If a
+// nullopt is provided for the expected sharding then it checks that no sharding
+// is present for an instruction.
+class HloShardingMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShardingMatcher(
+ const tensorflow::gtl::optional<HloSharding>& sharding)
+ : sharding_(sharding) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ tensorflow::gtl::optional<HloSharding> sharding_;
+};
+
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@@ -231,6 +276,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
}
+// Verifies the shape or the shape and the layout of an HLO instruction against
+// the provided shape object.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
+ const class Shape& shape) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
+}
+inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
+ const class Shape& shape) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShapeAndLayoutMatcher(shape));
+}
+
+// Verifies the value of the HloSharing against the provided sharding object.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
+ const HloSharding& sharding) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShardingMatcher(sharding));
+}
+// Verifies that no HloSharding is set for an HLO instruction.
+inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
+}
+
#undef HLO_MATCHER
} // namespace opcode_matchers