From c734063a84b505f6423cb2d2b7147eaccee1a8f2 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 21 May 2018 15:49:14 -0700 Subject: Extract out a MatrixMatrixBlockPanelEmitter::Dimensions struct; NFC This gives me a convenient place to note that the m/k/n here are not the m/k/n for the entire GEMM. I didn't rename m/k/n to mc/kc/nr since the latter seems somewhat redundant to me -- we could read 'c' as 'column' and 'r' as 'row', but that's the only possibility? This refactoring will also be useful when implementing GEPP on top of GEBP. PiperOrigin-RevId: 197474137 --- .../compiler/xla/service/cpu/dot_op_emitter.cc | 60 ++++++++++++++-------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index a0b11461ff..af69fc3da9 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -532,6 +532,23 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // matrices. class MatrixMatrixBlockPanelEmitter { public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself; the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies // `lhs` with `rhs` and stores the result in `result`. // @@ -547,16 +564,14 @@ class MatrixMatrixBlockPanelEmitter { // `k_tiling_factor` is the number of elements along the reduction dimensions // that we will attempt to process at once. 
explicit MatrixMatrixBlockPanelEmitter( - llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, int64 m, int64 k, - int64 n, int max_vectorization_width, int min_vectorization_width, + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims, + int max_vectorization_width, int min_vectorization_width, int k_tiling_factor, const TargetMachineFeatures& target_machine_features, llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type) : lhs_(lhs), rhs_(rhs), result_(result), - m_(m), - k_(k), - n_(n), + dims_(dims), max_vectorization_width_(max_vectorization_width), min_vectorization_width_(min_vectorization_width), k_tiling_factor_(k_tiling_factor), @@ -598,9 +613,7 @@ class MatrixMatrixBlockPanelEmitter { llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* result_; - int64 m_; - int64 k_; - int64 n_; + Dimensions dims_; int64 max_vectorization_width_; int64 min_vectorization_width_; @@ -617,9 +630,9 @@ void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); } void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { int64 current_vectorization_width = max_vectorization_width_; int64 n_start = 0; - while (n_start != n_ && + while (n_start != dims_.n() && current_vectorization_width >= min_vectorization_width_) { - int64 n_end = n_ - (n_ % current_vectorization_width); + int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width); if (n_start != n_end) { VectorSupportLibrary vsl(primitive_type_, current_vectorization_width, ir_builder_, "gebp"); @@ -629,9 +642,9 @@ void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { current_vectorization_width /= 2; } - if (n_start != n_) { + if (n_start != dims_.n()) { VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp"); - ksl_.For("epi.n", n_start, n_, 1, [&](llvm::Value* n_i) { + ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); EmitLoopOverK(&vsl, n_i, n_i_next); @@ 
-643,16 +656,16 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, llvm::Value* n_end) { int64 k_start = 0; - int64 k_end = k_ - (k_ % k_tiling_factor_); + int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_); if (k_end != k_start) { EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start, n_end, vsl); k_start = k_end; } - if (k_start != k_) { - EmitInnerLoop(k_ - k_start, getInt64(k_start), getInt64(k_), n_start, n_end, - vsl); + if (k_start != dims_.k()) { + EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()), + n_start, n_end, vsl); } } @@ -709,12 +722,12 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) { - ksl_.For("dot.m", 0, m_, 1, [&](llvm::Value* m_i) { + ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) { // This outer loop iterates over all of the M dimension llvm::Value* result_row_begin = vsl->ComputeOffsetPointer( - result_, /*offset_elements=*/m_i, /*scale=*/n_); - llvm::Value* lhs_row_begin = - vsl->ComputeOffsetPointer(lhs_, /*offset_elements=*/m_i, /*scale=*/k_); + result_, /*offset_elements=*/m_i, /*scale=*/dims_.n()); + llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer( + lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k()); ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) { // broadcasted_a is the broadcasted set of vectors denoted as , @@ -728,7 +741,8 @@ void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( // rhs_loader will be used to load the tile off of the RHS, denoted as // <, ...> in the diagram. 
- TileLoader rhs_loader(vsl, ir_builder_, rhs_, n_, k_i, k_tiling_factor); + TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i, + k_tiling_factor); ksl_.For( "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { // This loop iterates over the N dimension. It loads the tile from @@ -832,8 +846,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( target, ir_builder_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k, + /*n=*/n); MatrixMatrixBlockPanelEmitter gebp_emitter( - /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, /*m=*/m, /*k=*/k, /*n=*/n, + /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims, /*max_vectorization_width=*/8, /*min_vectorization_width=*/4, /*k_tiling_factor=*/8, target_machine_features_, ir_builder_, primitive_type); -- cgit v1.2.3