6 files changed, 328 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index f9c51f243c..e75fcb6bc9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -22,6 +22,8 @@ namespace {
 const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
 const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
 const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
+const char* const kXlaEnableExperimentalLlvmIrGemm =
+    "xla_enable_experimental_llvm_ir_gemm";
 
 }  // namespace
 
@@ -54,6 +56,12 @@ tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
   return tensorflow::gtl::nullopt;
 }
 
+bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
+  const auto& extra_options_map =
+      config.debug_options().xla_backend_extra_options();
+  return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
+}
+
 }  // namespace options
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index be62ff3cc1..106dfbbc62 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -26,6 +26,7 @@ namespace options {
 
 bool OptimizeForSizeRequested(const HloModuleConfig& config);
 bool VectorizedReduceDisabled(const HloModuleConfig& config);
+bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
 tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
     const HloModuleConfig& config);
 
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 5cdfc110af..a0b11461ff 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -521,6 +521,233 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
   }
 }
 
+// This class implements a tiled matrix multiplication algorithm, intended for
+// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in Goto,
+// Kazushige, and Robert Van De Geijn, "High-performance implementation of the
+// level-3 BLAS", ACM Transactions on Mathematical Software (TOMS) 35.1
+// (2008): 4).
+//
+// This only supports canonical dot operations (i.e. where the lhs contraction
+// dimension is 1 and the rhs contraction dimension is 0) over row major
+// matrices.
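For orientation, here is a minimal scalar C++ model (a sketch, not part of the patch; all names are hypothetical) of what this emitter generates. The generated code additionally vectorizes the n loop and handles the k remainder with a second, narrower inner-loop instantiation; the loop labels below match the ksl_.For names used in the implementation further down.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Scalar model of the GEBP kernel: result += lhs * rhs, all row major,
    // with the reduction (k) dimension processed k_tiling_factor elements at
    // a time.
    void GebpModel(const std::vector<float>& lhs, const std::vector<float>& rhs,
                   std::vector<float>& result, int64_t m, int64_t k, int64_t n,
                   int64_t k_tiling_factor) {
      for (int64_t mi = 0; mi < m; ++mi) {                     // "dot.m"
        for (int64_t ki = 0; ki < k; ki += k_tiling_factor) {  // "dot.k"
          int64_t k_tile = std::min(k_tiling_factor, k - ki);
          for (int64_t ni = 0; ni < n; ++ni) {                 // "dot.n"
            float acc = result[mi * n + ni];
            for (int64_t i = 0; i < k_tile; ++i) {
              acc += lhs[mi * k + ki + i] * rhs[(ki + i) * n + ni];
            }
            result[mi * n + ni] = acc;
          }
        }
      }
    }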
+class MatrixMatrixBlockPanelEmitter {
+ public:
+  // Creates an instance of MatrixMatrixBlockPanelEmitter that
+  // matrix-multiplies `lhs` with `rhs` and stores the result in `result`.
+  //
+  // `m`, `k` and `n` are the matrix multiplication dimensions.
+  //
+  // `max_vectorization_width` is the maximum vector width (i.e. the width of
+  // the largest vector register we will use).  This can be larger than the
+  // largest vector register supported by the machine -- LLVM will legalize
+  // these large vector widths into legally sized vectors.
+  //
+  // `min_vectorization_width` is the smallest vector width the emitter will
+  // use -- below that it will devolve to using a scalar loop.
+  //
+  // `k_tiling_factor` is the number of elements along the reduction dimension
+  // that we will attempt to process at once.
+  explicit MatrixMatrixBlockPanelEmitter(
+      llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, int64 m,
+      int64 k, int64 n, int max_vectorization_width,
+      int min_vectorization_width, int k_tiling_factor,
+      const TargetMachineFeatures& target_machine_features,
+      llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
+      : lhs_(lhs),
+        rhs_(rhs),
+        result_(result),
+        m_(m),
+        k_(k),
+        n_(n),
+        max_vectorization_width_(max_vectorization_width),
+        min_vectorization_width_(min_vectorization_width),
+        k_tiling_factor_(k_tiling_factor),
+        target_machine_features_(target_machine_features),
+        ir_builder_(ir_builder),
+        primitive_type_(primitive_type),
+        ksl_(ir_builder_) {
+    CHECK(max_vectorization_width > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(max_vectorization_width)));
+    CHECK(min_vectorization_width > 0 &&
+          IsPowerOfTwo(static_cast<uint64>(min_vectorization_width)));
+    CHECK_GT(k_tiling_factor, 0);
+  }
+
+  void Emit();
+
+ private:
+  // We can only iterate the `n` dimension for an extent that is divisible by
+  // the vectorization width.  So we emit an outer loop that first processes
+  // the largest extent in `n` that is divisible by max_vectorization_width,
+  // then the largest remaining extent that is divisible by
+  // max_vectorization_width / 2 etc.  This function emits that outermost
+  // loop.
+  void EmitChunkedLoopOverN();
+
+  // This emits a loop that loops over the `k` dimension in multiples of
+  // `k_tiling_factor` as much as possible and then emits a remainder
+  // epilogue.
+  void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+                     llvm::Value* n_end);
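A worked sketch (not part of the patch; the function name is hypothetical) of how the chunked loop over n described above partitions the column range. With the widths the patch later picks in EmitExperimentalGebpDotIfEnabled (max 8, min 4), PrintNChunks(23, 8, 4) prints "vector width 8: [0, 16)", "vector width 4: [16, 20)", "scalar epilogue: [20, 23)".

    #include <cstdint>
    #include <cstdio>

    // Mirrors EmitChunkedLoopOverN: peel off the largest prefix of [0, n)
    // divisible by the current width, halving the width until it drops below
    // min_width; leftover columns fall through to a scalar epilogue.
    void PrintNChunks(int64_t n, int64_t max_width, int64_t min_width) {
      int64_t width = max_width;
      int64_t n_start = 0;
      while (n_start != n && width >= min_width) {
        int64_t n_end = n - (n % width);
        if (n_start != n_end) {
          std::printf("vector width %lld: [%lld, %lld)\n",
                      static_cast<long long>(width),
                      static_cast<long long>(n_start),
                      static_cast<long long>(n_end));
          n_start = n_end;
        }
        width /= 2;
      }
      if (n_start != n) {
        std::printf("scalar epilogue: [%lld, %lld)\n",
                    static_cast<long long>(n_start),
                    static_cast<long long>(n));
      }
    }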
+  // This emits the inner reduction loop.  This inner reduction loop processes
+  // all indices in the `m` dimension, [`k_start`, `k_end`) in the `k`
+  // dimension and [`n_start`, `n_end`) in the `n` dimension.
+  void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start,
+                     llvm::Value* k_end, llvm::Value* n_start,
+                     llvm::Value* n_end, VectorSupportLibrary* vsl);
+
+  llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); }
+
+  llvm::Value* lhs_;
+  llvm::Value* rhs_;
+  llvm::Value* result_;
+  int64 m_;
+  int64 k_;
+  int64 n_;
+
+  int64 max_vectorization_width_;
+  int64 min_vectorization_width_;
+  int64 k_tiling_factor_;
+
+  const TargetMachineFeatures& target_machine_features_;
+  llvm::IRBuilder<>* ir_builder_;
+  PrimitiveType primitive_type_;
+  KernelSupportLibrary ksl_;
+};
+
+void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
+
+void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
+  int64 current_vectorization_width = max_vectorization_width_;
+  int64 n_start = 0;
+  while (n_start != n_ &&
+         current_vectorization_width >= min_vectorization_width_) {
+    int64 n_end = n_ - (n_ % current_vectorization_width);
+    if (n_start != n_end) {
+      VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
+                               ir_builder_, "gebp");
+      EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end));
+      n_start = n_end;
+    }
+    current_vectorization_width /= 2;
+  }
+
+  if (n_start != n_) {
+    VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
+    ksl_.For("epi.n", n_start, n_, 1, [&](llvm::Value* n_i) {
+      llvm::Value* n_i_next =
+          ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
+      EmitLoopOverK(&vsl, n_i, n_i_next);
+    });
+  }
+}
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
+                                                  llvm::Value* n_start,
+                                                  llvm::Value* n_end) {
+  int64 k_start = 0;
+  int64 k_end = k_ - (k_ % k_tiling_factor_);
+  if (k_end != k_start) {
+    EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end),
+                  n_start, n_end, vsl);
+    k_start = k_end;
+  }
+
+  if (k_start != k_) {
+    EmitInnerLoop(k_ - k_start, getInt64(k_start), getInt64(k_), n_start,
+                  n_end, vsl);
+  }
+}
+
+// The tiling scheme is as follows:
+//
+// Let the LHS be:
+//
+//   +---+---+---+
+//   | a | b | c | .
+//   +---+---+---+ .
+//   |   |   |   | .
+//   +---+---+---+
+//     ..    ..
+//
+// and the RHS be:
+//
+//   +----+----+----+----+
+//   | p0 | p1 | p2 | p3 | .
+//   +----+----+----+----+ .
+//   | q0 | q1 | q2 | q3 | .
+//   +----+----+----+----+
+//   | r0 | r1 | r2 | r3 | .
+//   +----+----+----+----+ .
+//   ......    ......
+//
+// and let k_tiling_factor be 3 and the vector width (implicitly denoted by
+// `vsl`) be 4.  Then we:
+//
+//   1. broadcast the first row in LHS to 3 vectors of width 4,
+//   2. elementwise multiply the RHS rows with these broadcasted vectors, and
+//   3. elementwise add them:
+//
+//   +---+---+---+---+   +----+----+----+----+
+//   | a | a | a | a | * | p0 | p1 | p2 | p3 |   +
+//   +---+---+---+---+   +----+----+----+----+
+//
+//   +---+---+---+---+   +----+----+----+----+
+//   | b | b | b | b | * | q0 | q1 | q2 | q3 |   +
+//   +---+---+---+---+   +----+----+----+----+
+//
+//   +---+---+---+---+   +----+----+----+----+
+//   | c | c | c | c | * | r0 | r1 | r2 | r3 |
+//   +---+---+---+---+   +----+----+----+----+
+//
+// to get:
+//
+//   +----------------+----------------+----------------+----------------+
+//   | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 |
+//   +----------------+----------------+----------------+----------------+
+//
+// which we increment into the appropriate region in the result.
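The diagram above, written out as a scalar C++ sketch (not part of the patch) for k_tiling_factor = 3 and vector width 4. A plain std::array stands in for an LLVM vector value, and all names are illustrative.

    #include <array>

    using Vec4 = std::array<float, 4>;  // stands in for a <4 x float> vector

    // acc += scalar * v, i.e. a broadcast followed by an FMA (vsl->MulAdd).
    Vec4 MulAdd(float scalar, const Vec4& v, Vec4 acc) {
      for (int i = 0; i < 4; ++i) acc[i] += scalar * v[i];
      return acc;
    }

    // One inner-loop step: with LHS row entries a, b, c and RHS row slices
    // p, q, r, compute result_tile += a*p + b*q + c*r elementwise, producing
    // <a*p0+b*q0+c*r0, a*p1+b*q1+c*r1, ...> as in the diagram.
    Vec4 GebpStep(float a, float b, float c, const Vec4& p, const Vec4& q,
                  const Vec4& r, Vec4 result_tile) {
      result_tile = MulAdd(a, p, result_tile);
      result_tile = MulAdd(b, q, result_tile);
      result_tile = MulAdd(c, r, result_tile);
      return result_tile;
    }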
+void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
+    int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
+    llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
+  ksl_.For("dot.m", 0, m_, 1, [&](llvm::Value* m_i) {
+    // This outer loop iterates over the entire M dimension.
+    llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
+        result_, /*offset_elements=*/m_i, /*scale=*/n_);
+    llvm::Value* lhs_row_begin =
+        vsl->ComputeOffsetPointer(lhs_, /*offset_elements=*/m_i, /*scale=*/k_);
+
+    ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
+      // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
+      // <b,b,b,b> etc. in the diagram.
+      std::vector<llvm::Value*> broadcasted_a;
+      broadcasted_a.reserve(k_tiling_factor);
+      for (int i = 0; i < k_tiling_factor; i++) {
+        broadcasted_a.push_back(vsl->LoadBroadcast(
+            lhs_row_begin, ir_builder_->CreateAdd(getInt64(i), k_i)));
+      }
+
+      // rhs_loader will be used to load the tile off of the RHS, denoted as
+      // <<p0,p1,p2,p3>, <q0,q1,q2,q3>, ...> in the diagram.
+      TileLoader rhs_loader(vsl, ir_builder_, rhs_, n_, k_i, k_tiling_factor);
+      ksl_.For(
+          "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
+            // This loop iterates over the N dimension.  It loads the tile
+            // from the RHS, does the FMAs producing
+            // <a*p0+b*q0+c*r0, a*p1+b*q1+c*r1, ...> from the diagram, and
+            // increments the result.
+            std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i);
+            llvm::Value* result_accumulator =
+                vsl->LoadVector(result_row_begin, n_i);
+            for (int i = 0; i < tile.size(); i++) {
+              result_accumulator =
+                  vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator);
+            }
+            vsl->StoreVector(result_accumulator, result_row_begin, n_i);
+          });
+    });
+  });
+}
+
 }  // namespace
 
 DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
@@ -558,6 +785,62 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
   return dot_emitter.Emit();
 }
 
+bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
+    const DotOpEmitter::MatMultDims& mat_mult_dims) {
+  if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) {
+    return false;
+  }
+
+  if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) {
+    return false;
+  }
+
+  PrimitiveType primitive_type = dot_.shape().element_type();
+
+  switch (primitive_type) {
+    default:
+      return false;
+
+    case F32:
+    case F64:
+    case S32:
+    case S64:
+      break;
+  }
+
+  if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major &&
+        mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) {
+    return false;
+  }
+
+  VLOG(2) << "Emitting GEBP kernel in LLVM IR";
+
+  llvm::Value* lhs = lhs_array_.GetBasePointer();
+  llvm::Value* rhs = rhs_array_.GetBasePointer();
+  llvm::Value* target = target_array_.GetBasePointer();
+  int64 m = mat_mult_dims.m;
+  int64 k = mat_mult_dims.k;
+  int64 n = mat_mult_dims.n;
+
+  if (mat_mult_dims.lhs_column_major) {
+    std::swap(lhs, rhs);
+    std::swap(m, n);
+  }
+
+  int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+  ir_builder_->CreateMemSet(
+      target, ir_builder_->getInt8(0), size_bytes,
+      target_machine_features_.minimum_alignment_for_allocation(size_bytes));
+
+  MatrixMatrixBlockPanelEmitter gebp_emitter(
+      /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, /*m=*/m, /*k=*/k, /*n=*/n,
+      /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
+      /*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
+      primitive_type);
+  gebp_emitter.Emit();
+  return true;
+}
+
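An aside on the std::swap calls above (not part of the patch): when all three matrices are column major, reading each buffer as row major yields the transpose, and (A*B)^T = B^T * A^T, so a row major kernel can compute the product by multiplying rhs*lhs with m and n exchanged. A self-contained numeric check, where RowMajorGemm is a hypothetical reference kernel:

    #include <cstdio>

    // Row major reference GEMM: C(m x n) += A(m x k) * B(k x n).
    void RowMajorGemm(const float* a, const float* b, float* c, int m, int k,
                      int n) {
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
          for (int p = 0; p < k; ++p)
            c[i * n + j] += a[i * k + p] * b[p * n + j];
    }

    int main() {
      // A, B, C stored column major (2x2): A = [[1,2],[3,4]],
      // B = [[5,6],[7,8]].
      float a[] = {1, 3, 2, 4}, b[] = {5, 7, 6, 8}, c[4] = {0, 0, 0, 0};
      // Column major C = A*B is computed as a row major rhs*lhs, with m and n
      // exchanged in the general (non-square) case, mirroring the std::swap
      // calls above.
      RowMajorGemm(b, a, c, /*m=*/2, /*k=*/2, /*n=*/2);
      // c, read column major, is A*B = [[19,22],[43,50]].
      std::printf("%g %g / %g %g\n", c[0], c[2], c[1], c[3]);
    }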
 bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
   if (dot_.shape().dimensions_size() != 2) {
     return false;
@@ -610,7 +893,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
   }
 
   if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
-    return false;
+    return EmitExperimentalGebpDotIfEnabled(mat_mult_dims);
   }
 
   int64 tiling_factor = GetGemvTilingFactor();
@@ -909,8 +1192,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
   // The two transpose_... parameters are actually booleans, but we use int32
   // to avoid target-dependent calling convention details.
-  bool multi_threaded =
-      hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+  bool multi_threaded = ShouldUseMultiThreadedEigen();
   bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
   PrimitiveType type = target_array_.GetShape().element_type();
   llvm::Type* float_type;
@@ -1019,7 +1301,9 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
       /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
       /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
       /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
-      /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1};
+      /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1,
+      /*target_column_major=*/
+      LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
 }
 
 llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index a75b8ffcbf..d88ccea0db 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -123,6 +123,9 @@ class DotOpEmitter {
 
     // True if the RHS contraction dimension is not 0.
     bool rhs_non_canonical;
+
+    // True if the result matrix is column major.
+    bool target_column_major;
   };
 
   // Get the MatMultDims instance for the dot product this DotOpEmitter
   // of rank 2 as well).
   MatMultDims GetMatMultDims() const;
 
+  bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);
+
   // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
   // registers.
   int64 GetGemvTilingFactor() const {
@@ -138,6 +143,17 @@
         .value_or(kDefaultTilingFactor);
   }
 
+  // Returns true if we should use an experimental implementation of GEMM
+  // (general matrix matrix multiplication) if possible.
+  bool EnableExperimentalLlvmIrGemm() const {
+    return options::EnableExperimentalLlvmIrGemm(hlo_module_config_);
+  }
+
+  // Returns true if we should call into multi-threaded Eigen routines.
+  bool ShouldUseMultiThreadedEigen() {
+    return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+  }
+
   const HloInstruction& dot_;
   const llvm_ir::IrArray& target_array_;
   const llvm_ir::IrArray& lhs_array_;
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index 6479bf76aa..edcaec5849 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -144,6 +144,12 @@ class VectorSupportLibrary {
   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                     llvm::Value* offset_elements);
   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+                                    llvm::Value* offset_elements, int64 scale) {
+    return ComputeOffsetPointer(
+        base_pointer,
+        ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements));
+  }
+  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                     int64 offset_elements) {
     return ComputeOffsetPointer(base_pointer,
                                 ir_builder()->getInt64(offset_elements));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 1c00b2aabd..64b935bbf1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -101,6 +101,15 @@ class KernelSupportLibrary {
   }
 
   void For(
+      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+      int64 step,
+      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+    For(name, start, end, ir_builder_->getInt64(step),
+        /*peel_first_iteration=*/false,
+        [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+  }
+
+  void For(
       tensorflow::StringPiece name, int64 start, int64 end, int64 step,
       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
     For(name, /*start=*/ir_builder_->getInt64(start),
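For reference, a sketch (not part of the patch) of how one might turn the new emitter on, assuming the DebugOptions plumbing of this era of the tree; the generated proto accessors shown are assumptions based on xla_backend_extra_options being a map<string, string> field.

    #include "tensorflow/compiler/xla/xla.pb.h"

    // The new emitter is keyed purely on the presence of the
    // "xla_enable_experimental_llvm_ir_gemm" key in xla_backend_extra_options
    // (see EnableExperimentalLlvmIrGemm above), so any value works.  It is
    // also gated off whenever multi-threaded Eigen is enabled, so
    // xla_cpu_multi_thread_eigen must be false for the GEBP kernel to fire.
    void EnableExperimentalGemm(xla::DebugOptions* debug_options) {
      (*debug_options->mutable_xla_backend_extra_options())
          ["xla_enable_experimental_llvm_ir_gemm"] = "";
      debug_options->set_xla_cpu_multi_thread_eigen(false);
    }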