aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc50
1 files changed, 23 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index d77076546f..c5c95a3c2c 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -729,21 +729,17 @@ class MatrixMatrixBlockPanelEmitter {
void Emit();
private:
- // This emits a loop that loops over the `n` dimension in multiples of
- // `max_vectorization_width` as much as possible and then emits a remainder
- // epilogue.
- void EmitLoopOverN();
-
- // This emits a loop that loops over the `k` dimension in multiples of
- // `tile_size_k` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
- llvm::Value* n_end);
-
- // This emits a loop that loops over the `m` dimension in multiples of
- // `tile_size_m` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
- llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end);
+ // The HandleResiduesOnX helpers split the iteration space for dimension X
+ // into a multiple of the tile size on dimension X and an epilogue. These
+ // helpers ultimately call into `EmitTiledReductionLoop` for emitting the
+ // tiled GEMM kernel.
+
+ void HandleResiduesOnN();
+ void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+ llvm::Value* n_end);
+ void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end);
// This emits the inner reduction loop. This inner reduction loop multiplies
// a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the
@@ -779,9 +775,9 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
@@ -796,7 +792,7 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
ir_builder_, "gebp");
- EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end));
+ HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
current_vectorization_width /= 2;
@@ -807,29 +803,29 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
- EmitLoopOverK(&vsl, n_i, n_i_next);
+ HandleResiduesOnK(&vsl, n_i, n_i_next);
});
}
}
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
- llvm::Value* n_start,
- llvm::Value* n_end) {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
- EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
- n_start, n_end);
+ HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+ n_start, n_end);
k_start = k_end;
}
if (k_start != dims().k()) {
- EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
- GetInt64(dims().k()), n_start, n_end);
+ HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
+ GetInt64(dims().k()), n_start, n_end);
}
}
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();