aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_matchers.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-04 22:04:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 15:40:20 -0700
commit150089e6e67e4492f098cdd8f9f2f48dc9f9cc56 (patch)
tree778d8f20ab300ceea85a36c22d150570ff9530f8 /tensorflow/compiler/xla/service/hlo_matchers.h
parent939fc534a4b2f227ee337e7dcfa82ec9b6337814 (diff)
Remove uses of the kTransposeDot fusion
I didn't remove the enum itself, but after this change removing the enum should be a simple NFC change (famous last words!). This will make it easier to implement BatchDot on CPU. The change removes usages of kTransposeDot by: - Teaching TransposeFolding to "fuse" transposes into dots by flipping the lhs_contracting_dims/rhs_contracting_dims fields. - Replacing the notion of transpose_lhs/transpose_rhs in the IR emitters with "has a non-canonical LHS contraction dimension"/"has a non-canonical RHS contraction dimension" where the canonical LHS and RHS contraction dims [0] are 1 and 0. Some tests were getting away with creating Dot instructions with their dimensions numbers unset. I've fixed these to create canonical dot operations instead. It is possible (but hard to tell without trying) that some of the IR emission logic and Eigen runtime calls can now be simplified further. For instance, instead of passing in a `transpose_lhs` and `transpose_rhs` to the Eigen GEMM routines, we could instead pass in the LHS and RHS contraction dimensions directly. [0] See HloInstruction::CreateCanonicalDot. PiperOrigin-RevId: 195514907
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h46
1 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 5175736a25..75231beac7 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -131,6 +131,27 @@ class HloShardingMatcher
tensorflow::gtl::optional<HloSharding> sharding_;
};
+// Matches a Dot HLO instruction with specific LHS and RHS contracting
+// dimensions.
+class HloDotWithContractDimsMatcher : public HloMatcher {
+ public:
+ explicit HloDotWithContractDimsMatcher(
+ ::testing::Matcher<const HloInstruction*> lhs,
+ ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
+ int64 rhs_contracting_dim)
+ : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
+ lhs_contracting_dim_(lhs_contracting_dim),
+ rhs_contracting_dim_(rhs_contracting_dim) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ int64 lhs_contracting_dim_;
+ int64 rhs_contracting_dim_;
+};
+
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@@ -158,7 +179,6 @@ HLO_MATCHER(Convolution);
HLO_MATCHER(Copy);
HLO_MATCHER(CrossReplicaSum);
HLO_MATCHER(Divide);
-HLO_MATCHER(Dot);
HLO_MATCHER(DynamicSlice);
HLO_MATCHER(DynamicUpdateSlice);
HLO_MATCHER(Eq);
@@ -310,6 +330,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
}
+inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
+ ::testing::Matcher<const HloInstruction*> lhs_matcher,
+ ::testing::Matcher<const HloInstruction*> rhs_matcher) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
+ ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
+}
+
+// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
+// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
+// equal to `rhs_contracting_dim`.
+//
+// Currently the HLO verifier rejects Dot operations with more than one
+// contracting dimension (even though we can represent these in the
+// DotDimensionNumbers proto) so there is no need to generalize this to support
+// multiple contracting dimensions.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
+ ::testing::Matcher<const HloInstruction*> lhs_matcher,
+ ::testing::Matcher<const HloInstruction*> rhs_matcher,
+ int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloDotWithContractDimsMatcher(
+ lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
+}
+
#undef HLO_MATCHER
} // namespace opcode_matchers