diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-05-04 22:04:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 15:40:20 -0700 |
commit | 150089e6e67e4492f098cdd8f9f2f48dc9f9cc56 (patch) | |
tree | 778d8f20ab300ceea85a36c22d150570ff9530f8 /tensorflow/compiler/xla/service/hlo_matchers.h | |
parent | 939fc534a4b2f227ee337e7dcfa82ec9b6337814 (diff) |
Remove uses of the kTransposeDot fusion
I didn't remove the enum itself, but after this change removing the enum should
be a simple NFC change (famous last words!).
This will make it easier to implement BatchDot on CPU.
The change removes usages of kTransposeDot by:
- Teaching TransposeFolding to "fuse" transposes into dots by flipping the
lhs_contracting_dims/rhs_contracting_dims fields.
- Replacing the notion of transpose_lhs/transpose_rhs in the IR emitters with
"has a non-canonical LHS contraction dimension"/"has a non-canonical RHS
contraction dimension" where the canonical LHS and RHS contraction dims [0]
are 1 and 0.
Some tests were getting away with creating Dot instructions with their
dimensions numbers unset. I've fixed these to create canonical dot operations
instead.
It is possible (but hard to tell without trying) that some of the IR emission
logic and Eigen runtime calls can now be simplified further. For instance,
instead of passing in a `transpose_lhs` and `transpose_rhs` to the Eigen GEMM
routines, we could instead pass in the LHS and RHS contraction dimensions
directly.
[0] See HloInstruction::CreateCanonicalDot.
PiperOrigin-RevId: 195514907
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers.h | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 5175736a25..75231beac7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -131,6 +131,27 @@ class HloShardingMatcher tensorflow::gtl::optional<HloSharding> sharding_; }; +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractDimsMatcher( + ::testing::Matcher<const HloInstruction*> lhs, + ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -158,7 +179,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -310,6 +330,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() { new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); } +inline ::testing::Matcher<const ::xla::HloInstruction*> Dot( + ::testing::Matcher<const HloInstruction*> lhs_matcher, + ::testing::Matcher<const HloInstruction*> rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher<const ::xla::HloInstruction*> Dot( + ::testing::Matcher<const HloInstruction*> lhs_matcher, + ::testing::Matcher<const HloInstruction*> rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers |