-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.cc                   8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_options.h                    1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc              292
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h                16
-rw-r--r--  tensorflow/compiler/xla/service/cpu/vector_support_library.h         6
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h     9
6 files changed, 328 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index f9c51f243c..e75fcb6bc9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -22,6 +22,8 @@ namespace {
const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
+const char* const kXlaEnableExperimentalLlvmIrGemm =
+ "xla_enable_experimental_llvm_ir_gemm";
} // namespace
@@ -54,6 +56,12 @@ tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
return tensorflow::gtl::nullopt;
}
+bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
+ const auto& extra_options_map =
+ config.debug_options().xla_backend_extra_options();
+ return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
+}
+
} // namespace options
} // namespace cpu
} // namespace xla
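
Note: the new xla_enable_experimental_llvm_ir_gemm knob is keyed off the
xla_backend_extra_options map in DebugOptions, so it is enabled by adding that
key to the map rather than via a dedicated flag. A minimal sketch of turning it
on programmatically, assuming only the standard protobuf map accessors and
HloModuleConfig::set_debug_options (nothing below is part of this patch):

    #include "tensorflow/compiler/xla/service/hlo_module_config.h"
    #include "tensorflow/compiler/xla/xla.pb.h"

    void EnableExperimentalGemm(xla::HloModuleConfig* config) {
      xla::DebugOptions debug_options = config->debug_options();
      // Only the presence of the key matters to EnableExperimentalLlvmIrGemm;
      // the value is ignored.
      (*debug_options.mutable_xla_backend_extra_options())
          ["xla_enable_experimental_llvm_ir_gemm"] = "";
      config->set_debug_options(debug_options);
    }

From the command line, the same key can usually be passed through the
--xla_backend_extra_options flag (e.g. in XLA_FLAGS), if that flag is available
in the build.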
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index be62ff3cc1..106dfbbc62 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -26,6 +26,7 @@ namespace options {
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
+bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
const HloModuleConfig& config);
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 5cdfc110af..a0b11461ff 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -521,6 +521,233 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
}
}
+// This class implements a tiled matrix multiplication algorithm, intended for
+// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto,
+// Kazushige, and Robert Van De Geijn. "High-performance implementation of the
+// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008):
+// 4).
+//
+// This only supports canonical dot operations (i.e. where the lhs contraction
+// dimension is 1 and the rhs contraction dimension is 0) over row major
+// matrices.
+class MatrixMatrixBlockPanelEmitter {
+ public:
+ // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+ // `lhs` with `rhs` and stores the result in `result`.
+ //
+ // `m`, `k` and `n` are the matrix multiplication dimensions.
+ //
+ // `max_vectorization_width` is the maximum vector width (i.e. the width of
+ // the largest vector register we will use). This can be larger than the
+ // largest vector register supported by the machine -- LLVM will legalize
+ // these large vector widths into legally sized vectors.
+ // `min_vectorization_width` is the smallest vector width the emitter will use
+ // -- below that it will devolve to using a scalar loop.
+ //
+ // `k_tiling_factor` is the number of elements along the reduction dimension
+ // that we will attempt to process at once.
+ explicit MatrixMatrixBlockPanelEmitter(
+ llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, int64 m, int64 k,
+ int64 n, int max_vectorization_width, int min_vectorization_width,
+ int k_tiling_factor, const TargetMachineFeatures& target_machine_features,
+ llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
+ : lhs_(lhs),
+ rhs_(rhs),
+ result_(result),
+ m_(m),
+ k_(k),
+ n_(n),
+ max_vectorization_width_(max_vectorization_width),
+ min_vectorization_width_(min_vectorization_width),
+ k_tiling_factor_(k_tiling_factor),
+ target_machine_features_(target_machine_features),
+ ir_builder_(ir_builder),
+ primitive_type_(primitive_type),
+ ksl_(ir_builder_) {
+ CHECK(max_vectorization_width > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(max_vectorization_width)));
+ CHECK(min_vectorization_width > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(min_vectorization_width)));
+ CHECK_GT(k_tiling_factor, 0);
+ }
+
+ void Emit();
+
+ private:
+ // We can only process the `n` dimension with vectors over an extent that is
+ // divisible by the vectorization width.  So we emit an outer loop that first
+ // processes the largest extent in `n` that is divisible by
+ // max_vectorization_width, then the largest remaining extent that is
+ // divisible by max_vectorization_width / 2, and so on.  This function emits
+ // that outermost loop.
+ void EmitChunkedLoopOverN();
+
+ // This emits a loop that steps over the `k` dimension in multiples of
+ // `k_tiling_factor` as far as possible and then emits a remainder epilogue.
+ void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+ llvm::Value* n_end);
+
+ // This emits the inner reduction loop. This inner reduction loop processes
+ // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension
+ // and [`n_start`, `n_end`) in the `n` dimension.
+ void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start,
+ llvm::Value* k_end, llvm::Value* n_start,
+ llvm::Value* n_end, VectorSupportLibrary* vsl);
+
+ llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); }
+
+ llvm::Value* lhs_;
+ llvm::Value* rhs_;
+ llvm::Value* result_;
+ int64 m_;
+ int64 k_;
+ int64 n_;
+
+ int64 max_vectorization_width_;
+ int64 min_vectorization_width_;
+ int64 k_tiling_factor_;
+
+ const TargetMachineFeatures& target_machine_features_;
+ llvm::IRBuilder<>* ir_builder_;
+ PrimitiveType primitive_type_;
+ KernelSupportLibrary ksl_;
+};
+
+void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
+
+void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
+ int64 current_vectorization_width = max_vectorization_width_;
+ int64 n_start = 0;
+ while (n_start != n_ &&
+ current_vectorization_width >= min_vectorization_width_) {
+ int64 n_end = n_ - (n_ % current_vectorization_width);
+ if (n_start != n_end) {
+ VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
+ ir_builder_, "gebp");
+ EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end));
+ n_start = n_end;
+ }
+ current_vectorization_width /= 2;
+ }
+
+ if (n_start != n_) {
+ VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
+ ksl_.For("epi.n", n_start, n_, 1, [&](llvm::Value* n_i) {
+ llvm::Value* n_i_next =
+ ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
+ EmitLoopOverK(&vsl, n_i, n_i_next);
+ });
+ }
+}
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
+ int64 k_start = 0;
+ int64 k_end = k_ - (k_ % k_tiling_factor_);
+ if (k_end != k_start) {
+ EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start,
+ n_end, vsl);
+ k_start = k_end;
+ }
+
+ if (k_start != k_) {
+ EmitInnerLoop(k_ - k_start, getInt64(k_start), getInt64(k_), n_start, n_end,
+ vsl);
+ }
+}
+
+// The tiling scheme is as follows:
+//
+// Let the LHS be:
+//
+// +---+---+---+
+// | a | b | c | .
+// +---+---+---+ .
+// | | | | .
+// +---+---+---+
+// .. ..
+//
+// and the RHS be:
+//
+// +----+----+----+----+
+// | p0 | p1 | p2 | p3 | .
+// +----+----+----+----+ .
+// | q0 | q1 | q2 | q3 | .
+// +----+----+----+----+
+// | r0 | r1 | r2 | r3 | .
+// +----+----+----+----+ .
+// ...... ......
+//
+// and let k_tiling_factor be 3 and the vector width (implicitly denoted by
+// `vsl`) be 4.
+//
+// Then we
+//
+// 1. broadcast the first row in LHS to 3 vectors of width 4
+// 2. elementwise multiply the RHS rows with these broadcasted vectors
+// 3. elementwise add them:
+//
+// +---+---+---+---+ +----+----+----+----+
+// | a | a | a | a | * | p0 | p1 | p2 | p3 | +
+// +---+---+---+---+ +----+----+----+----+
+//
+// +---+---+---+---+ +----+----+----+----+
+// | b | b | b | b | * | q0 | q1 | q2 | q3 | +
+// +---+---+---+---+ +----+----+----+----+
+//
+// +---+---+---+---+ +----+----+----+----+
+// | c | c | c | c | * | r0 | r1 | r2 | r3 |
+// +---+---+---+---+ +----+----+----+----+
+//
+// to get:
+//
+// +----------------+----------------+----------------+----------------+
+// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 |
+// +----------------+----------------+----------------+----------------+
+//
+// which we increment into the appropriate region in the result.
+void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
+ int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
+ ksl_.For("dot.m", 0, m_, 1, [&](llvm::Value* m_i) {
+ // This outer loop iterates over the entire M dimension.
+ llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
+ result_, /*offset_elements=*/m_i, /*scale=*/n_);
+ llvm::Value* lhs_row_begin =
+ vsl->ComputeOffsetPointer(lhs_, /*offset_elements=*/m_i, /*scale=*/k_);
+
+ ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
+ // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
+ // <b,b,b,b> etc. in the diagram.
+ std::vector<llvm::Value*> broadcasted_a;
+ broadcasted_a.reserve(k_tiling_factor);
+ for (int i = 0; i < k_tiling_factor; i++) {
+ broadcasted_a.push_back(vsl->LoadBroadcast(
+ lhs_row_begin, ir_builder_->CreateAdd(getInt64(i), k_i)));
+ }
+
+ // rhs_loader will be used to load the tile from the RHS, denoted as
+ // <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
+ TileLoader rhs_loader(vsl, ir_builder_, rhs_, n_, k_i, k_tiling_factor);
+ ksl_.For(
+ "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
+ // This loop iterates over the N dimension. It loads the tile from
+ // RHS, does the FMA resulting in the
+ // <a*p0+b*q0+c*r0,a*p1+b*q1+c*r1,...> in the diagram and increments
+ // the result.
+ std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i);
+ llvm::Value* result_accumulator =
+ vsl->LoadVector(result_row_begin, n_i);
+ for (int i = 0; i < tile.size(); i++) {
+ result_accumulator =
+ vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator);
+ }
+ vsl->StoreVector(result_accumulator, result_row_begin, n_i);
+ });
+ });
+ });
+}
+
} // namespace
DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
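
For reference, the loop nest emitted by MatrixMatrixBlockPanelEmitter above
corresponds to the scalar C++ sketch below (a hypothetical model written for
this note, not code from the patch). The outer chunking over `n` by decreasing
vector widths is folded into the min() for brevity; concretely, with n = 23,
max_vectorization_width = 8 and min_vectorization_width = 4,
EmitChunkedLoopOverN covers [0, 16) with width-8 vectors, then [16, 20) with
width-4 vectors, and finishes [20, 23) in the scalar epilogue.

    #include <algorithm>
    #include <cstdint>

    // Scalar model of the emitted GEBP kernel: result (m x n) +=
    // lhs (m x k) * rhs (k x n), all row major.  `result` is assumed to be
    // zeroed beforehand, as the emitter arranges with a memset.
    void ReferenceGebp(const float* lhs, const float* rhs, float* result,
                       int64_t m, int64_t k, int64_t n, int64_t vector_width,
                       int64_t k_tiling_factor) {
      for (int64_t m_i = 0; m_i < m; ++m_i) {  // "dot.m"
        const float* lhs_row = lhs + m_i * k;
        float* result_row = result + m_i * n;
        for (int64_t k_i = 0; k_i < k; k_i += k_tiling_factor) {  // "dot.k"
          const int64_t k_tile = std::min(k_tiling_factor, k - k_i);
          for (int64_t n_i = 0; n_i < n; n_i += vector_width) {  // "dot.n"
            const int64_t n_tile = std::min(vector_width, n - n_i);
            // Each lane accumulates k_tile multiply-adds, matching the
            // <a*p0+b*q0+c*r0, a*p1+b*q1+c*r1, ...> row in the diagram.
            for (int64_t lane = 0; lane < n_tile; ++lane) {
              float acc = result_row[n_i + lane];
              for (int64_t kk = 0; kk < k_tile; ++kk) {
                acc += lhs_row[k_i + kk] * rhs[(k_i + kk) * n + n_i + lane];
              }
              result_row[n_i + lane] = acc;
            }
          }
        }
      }
    }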
@@ -558,6 +785,62 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
return dot_emitter.Emit();
}
+bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
+ const DotOpEmitter::MatMultDims& mat_mult_dims) {
+ if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) {
+ return false;
+ }
+
+ if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) {
+ return false;
+ }
+
+ PrimitiveType primitive_type = dot_.shape().element_type();
+
+ switch (primitive_type) {
+ default:
+ return false;
+
+ case F32:
+ case F64:
+ case S32:
+ case S64:
+ break;
+ }
+
+ if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major &&
+ mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) {
+ return false;
+ }
+
+ VLOG(2) << "Emitting GEBP kernel in LLVM IR";
+
+ llvm::Value* lhs = lhs_array_.GetBasePointer();
+ llvm::Value* rhs = rhs_array_.GetBasePointer();
+ llvm::Value* target = target_array_.GetBasePointer();
+ int64 m = mat_mult_dims.m;
+ int64 k = mat_mult_dims.k;
+ int64 n = mat_mult_dims.n;
+
+ if (mat_mult_dims.lhs_column_major) {
+ std::swap(lhs, rhs);
+ std::swap(m, n);
+ }
+
+ int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+ ir_builder_->CreateMemSet(
+ target, ir_builder_->getInt8(0), size_bytes,
+ target_machine_features_.minimum_alignment_for_allocation(size_bytes));
+
+ MatrixMatrixBlockPanelEmitter gebp_emitter(
+ /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, /*m=*/m, /*k=*/k, /*n=*/n,
+ /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
+ /*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
+ primitive_type);
+ gebp_emitter.Emit();
+ return true;
+}
+
bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
if (dot_.shape().dimensions_size() != 2) {
return false;
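
A note on the std::swap above: the GEBP emitter only understands row-major
operands, but when the LHS, RHS and target are all column major the same
buffers can still be handed to it, because (A * B)^T = B^T * A^T and a
column-major matrix reinterpreted as row major is its transpose. A scalar
illustration of that identity follows (hypothetical helpers written for this
note, not part of the patch):

    #include <cstdint>

    // Stand-in for the emitted kernel: row-major result (rows x cols) +=
    // lhs (rows x depth) * rhs (depth x cols).
    void RowMajorGebp(const float* lhs, const float* rhs, float* result,
                      int64_t rows, int64_t depth, int64_t cols) {
      for (int64_t i = 0; i < rows; ++i)
        for (int64_t d = 0; d < depth; ++d)
          for (int64_t j = 0; j < cols; ++j)
            result[i * cols + j] += lhs[i * depth + d] * rhs[d * cols + j];
    }

    // Column-major C (m x n) = A (m x k) * B (k x n), with `c` already zeroed:
    // viewed as row-major buffers this is C^T (n x m) = B^T (n x k) *
    // A^T (k x m), which is why EmitExperimentalGebpDotIfEnabled swaps
    // lhs <-> rhs and m <-> n while k stays put.
    void ColumnMajorDot(const float* a, const float* b, float* c, int64_t m,
                        int64_t k, int64_t n) {
      RowMajorGebp(/*lhs=*/b, /*rhs=*/a, /*result=*/c,
                   /*rows=*/n, /*depth=*/k, /*cols=*/m);
    }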
@@ -610,7 +893,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
}
if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
- return false;
+ return EmitExperimentalGebpDotIfEnabled(mat_mult_dims);
}
int64 tiling_factor = GetGemvTilingFactor();
@@ -909,8 +1192,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
// The two transpose_... parameters are actually booleans, but we use int32
// to avoid target-dependent calling convention details.
- bool multi_threaded =
- hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+ bool multi_threaded = ShouldUseMultiThreadedEigen();
bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
PrimitiveType type = target_array_.GetShape().element_type();
llvm::Type* float_type;
@@ -1019,7 +1301,9 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
/*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0,
/*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0,
/*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0,
- /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1};
+ /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1,
+ /*target_column_major=*/
+ LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
}
llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index a75b8ffcbf..d88ccea0db 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -123,6 +123,9 @@ class DotOpEmitter {
// True if the RHS contraction dimension is not 0.
bool rhs_non_canonical;
+
+ // True if the result matrix is column major.
+ bool target_column_major;
};
// Get the MatMultDims instance for the dot product this DotOpEmitter
@@ -130,6 +133,8 @@ class DotOpEmitter {
// of rank 2 as well).
MatMultDims GetMatMultDims() const;
+ bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims);
+
// When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
// registers.
int64 GetGemvTilingFactor() const {
@@ -138,6 +143,17 @@ class DotOpEmitter {
.value_or(kDefaultTilingFactor);
}
+ // Returns true if we should use an experimental implementation of GEMM
+ // (general matrix-matrix multiplication) when possible.
+ bool EnableExperimentalLlvmIrGemm() const {
+ return options::EnableExperimentalLlvmIrGemm(hlo_module_config_);
+ }
+
+ // Returns true if we should call into multi-threaded Eigen routines.
+ bool ShouldUseMultiThreadedEigen() {
+ return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+ }
+
const HloInstruction& dot_;
const llvm_ir::IrArray& target_array_;
const llvm_ir::IrArray& lhs_array_;
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index 6479bf76aa..edcaec5849 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -144,6 +144,12 @@ class VectorSupportLibrary {
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
llvm::Value* offset_elements);
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+ llvm::Value* offset_elements, int64 scale) {
+ return ComputeOffsetPointer(
+ base_pointer,
+ ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements));
+ }
+ llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
int64 offset_elements) {
return ComputeOffsetPointer(base_pointer,
ir_builder()->getInt64(offset_elements));
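
The new ComputeOffsetPointer overload simply scales a dynamic element index by
a compile-time stride before delegating to the existing overload; the GEBP
emitter uses it to form row pointers such as result + m_i * n_ and
lhs + m_i * k_. A rough scalar analogue (hypothetical helper, not part of the
library):

    #include <cstdint>

    // Scalar analogue of ComputeOffsetPointer(base_pointer, offset_elements,
    // scale): returns &base_pointer[offset_elements * scale].
    template <typename T>
    T* OffsetPointer(T* base_pointer, int64_t offset_elements, int64_t scale) {
      return base_pointer + offset_elements * scale;
    }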
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 1c00b2aabd..64b935bbf1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -101,6 +101,15 @@ class KernelSupportLibrary {
}
void For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step,
+ const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+ For(name, start, end, ir_builder_->getInt64(step),
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ }
+
+ void For(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
For(name, /*start=*/ir_builder_->getInt64(start),
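
The new For overload mirrors the existing (int64 start, int64 end, int64 step)
one, but takes llvm::Value* bounds with a constant step; that is the shape the
GEBP emitter needs for loops like "dot.k", whose bounds are only known as IR
values. A hedged usage sketch (`ksl` and `ir_builder` stand for an existing
KernelSupportLibrary and llvm::IRBuilder<>*; this is not code from the patch):

    // Emits a loop from k_start (inclusive) to k_end (exclusive) in steps of
    // 8; the body lambda receives the induction variable as an llvm::Value*.
    llvm::Value* k_start = ir_builder->getInt64(0);
    llvm::Value* k_end = ir_builder->getInt64(128);
    ksl.For("dot.k", k_start, k_end, /*step=*/8, [&](llvm::Value* k_i) {
      // ... emit IR that uses k_i ...
    });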