aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-21 15:49:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 15:59:36 -0700
commitc734063a84b505f6423cb2d2b7147eaccee1a8f2 (patch)
tree92763a6d9909ecca0d7c5654b7099c95b4905beb
parentb28938c3672db9c23f84298160658787d0ccf69d (diff)
Extract out a MatrixMatrixBlockPanelEmitter::Dimensions struct; NFC
This gives me a convenient place to note that the m/k/n here are not the m/k/n for the entire GEMM. I didn't rename m/k/n to mc/kc/nr since the latter seems somewhat redundant to me -- we could read 'c' as 'column' and 'r' as 'row', but that's the only possibility? This refactoring will also be useful when implementing GEPP on top of GEBP. PiperOrigin-RevId: 197474137
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc60
1 files changed, 38 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index a0b11461ff..af69fc3da9 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -532,6 +532,23 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
// matrices.
class MatrixMatrixBlockPanelEmitter {
public:
+ // Describes the dimensions of a GEBP kernel. These are usually not the
+ // dimensions of the GEMM itself; the GEMM is usually broken up into GEBP
+ // kernels with smaller dimensions.
+ class Dimensions {
+ public:
+ explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {}
+
+ int64 m() const { return m_; }
+ int64 k() const { return k_; }
+ int64 n() const { return n_; }
+
+ private:
+ const int64 m_;
+ const int64 k_;
+ const int64 n_;
+ };
+
// Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
// `lhs` with `rhs` and stores the result in `result`.
//
@@ -547,16 +564,14 @@ class MatrixMatrixBlockPanelEmitter {
// `k_tiling_factor` is the number of elements along the reduction dimensions
// that we will attempt to process at once.
explicit MatrixMatrixBlockPanelEmitter(
- llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, int64 m, int64 k,
- int64 n, int max_vectorization_width, int min_vectorization_width,
+ llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims,
+ int max_vectorization_width, int min_vectorization_width,
int k_tiling_factor, const TargetMachineFeatures& target_machine_features,
llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
: lhs_(lhs),
rhs_(rhs),
result_(result),
- m_(m),
- k_(k),
- n_(n),
+ dims_(dims),
max_vectorization_width_(max_vectorization_width),
min_vectorization_width_(min_vectorization_width),
k_tiling_factor_(k_tiling_factor),
@@ -598,9 +613,7 @@ class MatrixMatrixBlockPanelEmitter {
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* result_;
- int64 m_;
- int64 k_;
- int64 n_;
+ Dimensions dims_;
int64 max_vectorization_width_;
int64 min_vectorization_width_;
@@ -617,9 +630,9 @@ void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
int64 current_vectorization_width = max_vectorization_width_;
int64 n_start = 0;
- while (n_start != n_ &&
+ while (n_start != dims_.n() &&
current_vectorization_width >= min_vectorization_width_) {
- int64 n_end = n_ - (n_ % current_vectorization_width);
+ int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width);
if (n_start != n_end) {
VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
ir_builder_, "gebp");
@@ -629,9 +642,9 @@ void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
current_vectorization_width /= 2;
}
- if (n_start != n_) {
+ if (n_start != dims_.n()) {
VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
- ksl_.For("epi.n", n_start, n_, 1, [&](llvm::Value* n_i) {
+ ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
EmitLoopOverK(&vsl, n_i, n_i_next);
@@ -643,16 +656,16 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
llvm::Value* n_start,
llvm::Value* n_end) {
int64 k_start = 0;
- int64 k_end = k_ - (k_ % k_tiling_factor_);
+ int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_);
if (k_end != k_start) {
EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start,
n_end, vsl);
k_start = k_end;
}
- if (k_start != k_) {
- EmitInnerLoop(k_ - k_start, getInt64(k_start), getInt64(k_), n_start, n_end,
- vsl);
+ if (k_start != dims_.k()) {
+ EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()),
+ n_start, n_end, vsl);
}
}
@@ -709,12 +722,12 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
- ksl_.For("dot.m", 0, m_, 1, [&](llvm::Value* m_i) {
+ ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) {
// This outer loop iterates over all of the M dimension
llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
- result_, /*offset_elements=*/m_i, /*scale=*/n_);
- llvm::Value* lhs_row_begin =
- vsl->ComputeOffsetPointer(lhs_, /*offset_elements=*/m_i, /*scale=*/k_);
+ result_, /*offset_elements=*/m_i, /*scale=*/dims_.n());
+ llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
+ lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k());
ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
// broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
@@ -728,7 +741,8 @@ void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
// rhs_loader will be used to load the tile off of the RHS, denoted as
// <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
- TileLoader rhs_loader(vsl, ir_builder_, rhs_, n_, k_i, k_tiling_factor);
+ TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i,
+ k_tiling_factor);
ksl_.For(
"dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
// This loop iterates over the N dimension. It loads the tile from
@@ -832,8 +846,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
target, ir_builder_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
+ MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k,
+ /*n=*/n);
MatrixMatrixBlockPanelEmitter gebp_emitter(
- /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, /*m=*/m, /*k=*/k, /*n=*/n,
+ /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims,
/*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
/*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
primitive_type);