about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com>, 2018-05-23 17:49:42 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-05-23 17:52:23 -0700
commit: 335d52c17644f417bc53abe4ef87ead9de01ad6d (patch)
tree: 94a11f269aa54fe37f1523a72049d1d18af92c2a
parent: 0f5f23e876be89dc2a389be078289d3028ae6503 (diff)
Cache generated LLVM IR for GEBP
After this change, all generated GEBPs with the same shape will share a single llvm::Function. This is NFC (no functional change) for any actual workloads because the GEBP emitter isn't exercised by normal code paths yet.

PiperOrigin-RevId: 197820606
-rw-r--r-- tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc 152
1 files changed, 102 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 5158779910..3aa436b39a 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -610,16 +610,21 @@ class MatrixMatrixBlockPanelEmitter {
int64 k() const { return k_; }
int64 n() const { return n_; }
+ string ToString() const {
+ return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
+ }
+
private:
const int64 m_;
const int64 k_;
const int64 n_;
};
- // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
- // `lhs` with `rhs` and stores the result in `result`.
+ // Represents the configuration of the GEBP emitter. The LLVM IR emitted by
+ // the emitter, modulo the LLVM values holding the input and output buffers,
+ // must be a function of the instance of `Config` passed to it.
//
- // `m`, `k` and `n` are the matrix multiplication dimensions.
+ // `dims` holds the matrix multiplication dimensions.
//
// `max_vectorization_width` is the maximum vector width (i.e. the width of
// the largest vector register we will use). This can be larger than the
@@ -630,27 +635,54 @@ class MatrixMatrixBlockPanelEmitter {
//
// `k_tiling_factor` is the number of elements along the reduction dimensions
// that we will attempt to process at once.
- explicit MatrixMatrixBlockPanelEmitter(
- llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims,
- int max_vectorization_width, int min_vectorization_width,
- int k_tiling_factor, const TargetMachineFeatures& target_machine_features,
- llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
+ class Config {
+ public:
+ explicit Config(PrimitiveType scalar_type, Dimensions dims,
+ int64 max_vectorization_width,
+ int64 min_vectorization_width, int64 k_tiling_factor)
+ : scalar_type_(scalar_type),
+ dims_(dims),
+ max_vectorization_width_(max_vectorization_width),
+ min_vectorization_width_(min_vectorization_width),
+ k_tiling_factor_(k_tiling_factor) {}
+
+ string GetCacheKey() const {
+ return tensorflow::strings::StrCat(
+ "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
+ "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
+ k_tiling_factor());
+ }
+
+ PrimitiveType scalar_type() const { return scalar_type_; }
+ Dimensions dims() const { return dims_; }
+ int64 max_vectorization_width() const { return max_vectorization_width_; }
+ int64 min_vectorization_width() const { return min_vectorization_width_; }
+ int64 k_tiling_factor() const { return k_tiling_factor_; }
+
+ private:
+ PrimitiveType scalar_type_;
+ Dimensions dims_;
+ int64 max_vectorization_width_;
+ int64 min_vectorization_width_;
+ int64 k_tiling_factor_;
+ };
+
+ // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+ // `lhs` with `rhs` and stores the result in `result`.
+ explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* result,
+ llvm::IRBuilder<>* ir_builder)
: lhs_(lhs),
rhs_(rhs),
result_(result),
- dims_(dims),
- max_vectorization_width_(max_vectorization_width),
- min_vectorization_width_(min_vectorization_width),
- k_tiling_factor_(k_tiling_factor),
- target_machine_features_(target_machine_features),
+ config_(config),
ir_builder_(ir_builder),
- primitive_type_(primitive_type),
ksl_(ir_builder_) {
- CHECK(max_vectorization_width > 0 &&
- IsPowerOfTwo(static_cast<uint64>(max_vectorization_width)));
- CHECK(min_vectorization_width > 0 &&
- IsPowerOfTwo(static_cast<uint64>(min_vectorization_width)));
- CHECK_GT(k_tiling_factor, 0);
+ CHECK(max_vectorization_width() > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+ CHECK(min_vectorization_width() > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+ CHECK_GT(k_tiling_factor(), 0);
}
void Emit();
@@ -677,31 +709,37 @@ class MatrixMatrixBlockPanelEmitter {
llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); }
+ Config config() const { return config_; }
+ Dimensions dims() const { return config().dims(); }
+
+ int64 max_vectorization_width() const {
+ return config().max_vectorization_width();
+ }
+ int64 min_vectorization_width() const {
+ return config().min_vectorization_width();
+ }
+ int64 k_tiling_factor() const { return config().k_tiling_factor(); }
+ PrimitiveType scalar_type() const { return config().scalar_type(); }
+
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* result_;
- Dimensions dims_;
-
- int64 max_vectorization_width_;
- int64 min_vectorization_width_;
- int64 k_tiling_factor_;
+ Config config_;
- const TargetMachineFeatures& target_machine_features_;
llvm::IRBuilder<>* ir_builder_;
- PrimitiveType primitive_type_;
KernelSupportLibrary ksl_;
};
void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
- int64 current_vectorization_width = max_vectorization_width_;
+ int64 current_vectorization_width = max_vectorization_width();
int64 n_start = 0;
- while (n_start != dims_.n() &&
- current_vectorization_width >= min_vectorization_width_) {
- int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width);
+ while (n_start != dims().n() &&
+ current_vectorization_width >= min_vectorization_width()) {
+ int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
- VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
+ VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
ir_builder_, "gebp");
EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end));
n_start = n_end;
@@ -709,9 +747,9 @@ void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
current_vectorization_width /= 2;
}
- if (n_start != dims_.n()) {
- VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
- ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) {
+ if (n_start != dims().n()) {
+ VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
+ ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
EmitLoopOverK(&vsl, n_i, n_i_next);
@@ -723,15 +761,15 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
llvm::Value* n_start,
llvm::Value* n_end) {
int64 k_start = 0;
- int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_);
+ int64 k_end = dims().k() - (dims().k() % k_tiling_factor());
if (k_end != k_start) {
- EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start,
- n_end, vsl);
+ EmitInnerLoop(k_tiling_factor(), getInt64(k_start), getInt64(k_end),
+ n_start, n_end, vsl);
k_start = k_end;
}
- if (k_start != dims_.k()) {
- EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()),
+ if (k_start != dims().k()) {
+ EmitInnerLoop(dims().k() - k_start, getInt64(k_start), getInt64(dims().k()),
n_start, n_end, vsl);
}
}
@@ -789,12 +827,12 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
- ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) {
+ ksl_.For("dot.m", 0, dims().m(), 1, [&](llvm::Value* m_i) {
// This outer loop iterates over all of the M dimension
llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
- result_, /*offset_elements=*/m_i, /*scale=*/dims_.n());
+ result_, /*offset_elements=*/m_i, /*scale=*/dims().n());
llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
- lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k());
+ lhs_, /*offset_elements=*/m_i, /*scale=*/dims().k());
ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
// broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
@@ -808,7 +846,7 @@ void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
// rhs_loader will be used to load the tile off of the RHS, denoted as
// <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
- TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i,
+ TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims().n(), k_i,
k_tiling_factor);
ksl_.For(
"dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
@@ -913,14 +951,28 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
target, ir_builder_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
- MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k,
- /*n=*/n);
- MatrixMatrixBlockPanelEmitter gebp_emitter(
- /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims,
+ MatrixMatrixBlockPanelEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
/*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
- /*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
- primitive_type);
- gebp_emitter.Emit();
+ /*k_tiling_factor=*/8);
+
+ const bool enable_fast_math =
+ hlo_module_config_.debug_options().xla_enable_fast_math();
+ const bool optimize_for_size =
+ options::OptimizeForSizeRequested(hlo_module_config_);
+
+ KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ /*enable_fast_math=*/enable_fast_math,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs, rhs, target,
+ [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
+ MatrixMatrixBlockPanelEmitter gebp_emitter(
+ config, /*lhs=*/lhs, /*rhs=*/rhs,
+ /*result=*/target, ir_builder_);
+ gebp_emitter.Emit();
+ });
+
return true;
}