diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-24 07:06:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 07:09:20 -0700 |
commit | 5eb233d0686636a7bacc5b8813c079b6b9aa483c (patch) | |
tree | 5e2cbbe141fe5ba07e43d8765a9f7f2cee1af226 /tensorflow/compiler/xla/service/hlo_matchers.h | |
parent | b9e12bc69df65eca279a90045d045e661fdb8108 (diff) |
Introduce a new HLO shape and sharding matcher.
These new matchers can be used in tests in combination to the existing
HLO opcode matchers to better verify a generated HLO graph.
PiperOrigin-RevId: 194082100
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers.h | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 103f04a2cb..f2ab9b5d9b 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -86,6 +87,50 @@ class HloCustomCallMatcher : public HloMatcher { ::testing::Matcher<string> call_target_matcher_; }; +class HloShapeMatcher + : public ::testing::MatcherInterface<const HloInstruction*> { + public: + explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +class HloShapeAndLayoutMatcher + : public ::testing::MatcherInterface<const HloInstruction*> { + public: + explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +// Verify the sharding of an instruction against the provided HloSharding. If a +// nullopt is provided for the expected sharding then it checks that no sharding +// is present for an instruction. +class HloShardingMatcher + : public ::testing::MatcherInterface<const HloInstruction*> { + public: + explicit HloShardingMatcher( + const tensorflow::gtl::optional<HloSharding>& sharding) + : sharding_(sharding) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + tensorflow::gtl::optional<HloSharding> sharding_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -231,6 +276,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() { new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); } +// Verifies the shape or the shape and the layout of an HLO instruction against +// the provided shape object. +inline ::testing::Matcher<const ::xla::HloInstruction*> Shape( + const class Shape& shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); +} +inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout( + const class Shape& shape) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeAndLayoutMatcher(shape)); +} + +// Verifies the value of the HloSharing against the provided sharding object. +inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding( + const HloSharding& sharding) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(sharding)); +} +// Verifies that no HloSharding is set for an HLO instruction. +inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); +} + #undef HLO_MATCHER } // namespace opcode_matchers |