Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc')
 tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 303 ++++++++---------
 1 file changed, 127 insertions(+), 176 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 58228180ca..645888de78 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -49,15 +49,15 @@ class MemoryTile {
// `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
// `major_dim_offset` in the major dimension. The tile size along the minor
// dimension is the vector size, and that is implicitly determined by `vsl`.
- MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
+ MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b,
llvm::Value* matrix, int64 matrix_size_along_minor_dim,
llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
- : vsl_(vsl), ir_builder_(ir_builder) {
+ : vsl_(vsl), b_(b) {
pointers_.reserve(tile_size_along_major_dim);
for (int64 i = 0; i < tile_size_along_major_dim; i++) {
- llvm::Value* total_offset = ir_builder->CreateMul(
- ir_builder->getInt64(matrix_size_along_minor_dim),
- ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
+ llvm::Value* total_offset =
+ b->CreateMul(b->getInt64(matrix_size_along_minor_dim),
+ b->CreateAdd(b->getInt64(i), major_dim_offset));
pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
}
}
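
Note: the constructor above reduces to one base pointer per tile row. A minimal sketch of the equivalent address arithmetic (a restatement of the code, with illustrative phrasing):

    // For a buffer `matrix` whose minor (contiguous) dimension holds
    // `matrix_size_along_minor_dim` elements, the i-th tile pointer is
    //   pointers_[i] = matrix + matrix_size_along_minor_dim * (major_dim_offset + i)
    // LoadTile/LoadBroadcastTile then read vsl-vector-sized chunks off each pointer.
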
@@ -101,8 +101,7 @@ class MemoryTile {
for (int64 i = 0; i < pointers_.size(); i++) {
for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
result[i].push_back(vsl_->LoadBroadcast(
- pointers_[i], ir_builder_->CreateAdd(minor_dim_offset,
- ir_builder_->getInt64(j))));
+ pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j))));
}
}
return result;
@@ -110,7 +109,7 @@ class MemoryTile {
private:
VectorSupportLibrary* vsl_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
std::vector<llvm::Value*> pointers_;
};
@@ -249,16 +248,15 @@ class ColumnMajorMatrixVectorProductEmitter
ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result,
- llvm::IRBuilder<>* ir_builder)
+ llvm::IRBuilder<>* b)
: config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
- ir_builder_(ir_builder),
- ksl_(ir_builder_),
- vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(),
- ir_builder_, "") {
+ b_(b),
+ ksl_(b_),
+ vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") {
CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
CHECK(!has_addend() || addend != nullptr);
}
@@ -272,7 +270,7 @@ class ColumnMajorMatrixVectorProductEmitter
bool is_first_column);
MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
- return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
/*matrix_size_along_minor_dim=*/m(),
/*major_dim_offset=*/column_start,
/*tile_size_along_major_dim=*/column_count);
@@ -302,7 +300,7 @@ class ColumnMajorMatrixVectorProductEmitter
llvm::Value* rhs_;
llvm::Value* addend_;
llvm::Value* result_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
KernelSupportLibrary ksl_;
VectorSupportLibrary vsl_;
};
@@ -331,7 +329,7 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() {
});
if (column_remainder != 0) {
- EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
+ EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder,
column_limit == 0);
}
}
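
Note: this hunk shows only the remainder call; the full-tile loop sits outside the diff context. A hedged reconstruction of the tile/remainder split implied by column_limit and column_remainder, as scalar pseudocode:

    int64 column_remainder = k % tile_cols;     // columns left over after full tiles
    int64 column_limit = k - column_remainder;  // last tile-aligned column
    for (int64 col = 0; col < column_limit; col += tile_cols)
      EmitOuterLoopBody(col, tile_cols, /*is_first_column=*/col == 0);
    if (column_remainder != 0)
      EmitOuterLoopBody(column_limit, column_remainder, column_limit == 0);
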
@@ -364,7 +362,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
return;
}
- llvm::Value* columns_llvm = ir_builder_->getInt64(columns);
+ llvm::Value* columns_llvm = b_->getInt64(columns);
// for (col = current_tile_col; col < (columns + current_tile_col); col++)
// for (row = row_start, row < m_; row++) {
@@ -375,12 +373,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
ksl_.ForReturnVoid(
"dot.inner.epilg.outer", /*start=*/current_tile_col,
- /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
+ /*end=*/b_->CreateAdd(columns_llvm, current_tile_col),
/*step=*/1, /*peel_first_iteration=*/false,
[&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
- llvm::Value* total_offset =
- ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
+ llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.ForReturnVoid(
@@ -388,9 +385,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
- llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
- is_first_scalar_col,
- ir_builder_->getInt1(is_first_tiled_column));
+ llvm::Value* setting_result_first_time = b_->CreateAnd(
+ is_first_scalar_col, b_->getInt1(is_first_tiled_column));
ksl_.IfReturnVoid(
setting_result_first_time,
/*true_block_generator=*/
@@ -478,16 +474,15 @@ class RowMajorMatrixVectorProductEmitter
RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
- llvm::Value* result,
- llvm::IRBuilder<>* ir_builder)
+ llvm::Value* result, llvm::IRBuilder<>* b)
: config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
- ir_builder_(ir_builder),
- ksl_(ir_builder_),
- vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") {
+ b_(b),
+ ksl_(b_),
+ vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") {
CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
CHECK(!has_addend() || addend != nullptr);
}
@@ -498,7 +493,7 @@ class RowMajorMatrixVectorProductEmitter
private:
MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
- return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ return MemoryTile(&vsl_, b_, /*matrix=*/lhs_,
/*matrix_size_along_minor_dim=*/k(),
/*major_dim_offset=*/row_start,
/*tile_size_along_major_dim=*/row_count);
@@ -517,7 +512,7 @@ class RowMajorMatrixVectorProductEmitter
llvm::Value* rhs_;
llvm::Value* addend_;
llvm::Value* result_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
KernelSupportLibrary ksl_;
VectorSupportLibrary vsl_;
};
@@ -559,7 +554,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
for (int i = 0; i < row_count; i++) {
llvm::Value* result_value =
vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get());
- llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
+ llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row);
if (addend_ && row_count != vsl_.vector_size()) {
result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value);
}
@@ -578,7 +573,7 @@ void RowMajorMatrixVectorProductEmitter::Emit() {
[&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
if (row_remainder != 0) {
- EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
+ EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder);
}
}
@@ -609,9 +604,8 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
}
for (int r = 0; r < rows; r++) {
- llvm::Value* total_offset = ir_builder_->CreateMul(
- ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
- ir_builder_->getInt64(k()));
+ llvm::Value* total_offset = b_->CreateMul(
+ b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.ForReturnVoid(
@@ -722,13 +716,13 @@ class MatrixMatrixBlockPanelEmitter {
// `lhs` with `rhs` and stores the result in `result`.
explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* result,
- llvm::IRBuilder<>* ir_builder)
+ llvm::IRBuilder<>* b)
: lhs_(lhs),
rhs_(rhs),
result_(result),
config_(config),
- ir_builder_(ir_builder),
- ksl_(ir_builder_) {
+ b_(b),
+ ksl_(b_) {
CHECK(max_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
CHECK_GT(max_vector_count(), 0);
@@ -761,7 +755,7 @@ class MatrixMatrixBlockPanelEmitter {
int64 tile_size_m, llvm::Value* m_start,
llvm::Value* m_end);
- llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
+ llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); }
Config config() const { return config_; }
Dimensions dims() const { return config().dims(); }
@@ -782,7 +776,7 @@ class MatrixMatrixBlockPanelEmitter {
llvm::Value* result_;
Config config_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
KernelSupportLibrary ksl_;
};
@@ -804,8 +798,8 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
current_vectorization_width >= min_vectorization_width()) {
int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
- VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
- ir_builder_, "gebp");
+ VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_,
+ "gebp");
HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
@@ -819,10 +813,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
}
if (n_start != dims().n()) {
- VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
+ VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp");
ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
- llvm::Value* n_i_next =
- ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
+ llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1));
HandleResiduesOnK(&vsl, n_i, n_i_next);
});
}
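
Note: HandleResiduesOnN() covers the N dimension with successively narrower vectors, then falls back to the width-1 "epi.n" loop shown above for the last few columns. A hedged sketch of the width schedule (the halving step itself falls outside this hunk):

    int64 n_start = 0;
    for (int64 w = max_vectorization_width; w >= min_vectorization_width; w /= 2) {
      int64 n_end = n - (n % w);  // largest multiple of w not exceeding n
      if (n_start != n_end) {
        // handle columns [n_start, n_end) with w-wide vectors via HandleResiduesOnK
        n_start = n_end;
      }
    }
    // scalar epilogue handles [n_start, n) one column at a time
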
@@ -935,11 +928,11 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
ksl_.ForReturnVoid(
"dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
MemoryTile result_memory_tile(
- vsl, ir_builder_, /*matrix=*/result_,
+ vsl, b_, /*matrix=*/result_,
/*matrix_size_along_minor_dim=*/dims().n(),
/*major_dim_offset=*/m_i,
/*tile_size_along_major_dim=*/tile_size_m);
- MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+ MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_,
/*matrix_size_along_minor_dim=*/dims().k(),
/*major_dim_offset=*/m_i,
/*tile_size_along_major_dim=*/tile_size_m);
@@ -949,8 +942,8 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
result_memory_tile.LoadTile(n_i));
ksl_.ForReturnVoid(
"dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
- MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_,
- dims().n(), k_i, tile_size_k);
+ MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i,
+ tile_size_k);
std::vector<std::vector<llvm::Value*>> lhs_tile =
lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
std::vector<llvm::Value*> rhs_tile =
@@ -980,7 +973,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
llvm::Value* executable_run_options_value,
- llvm::IRBuilder<>* ir_builder,
+ llvm::IRBuilder<>* b,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features)
: dot_(dot),
@@ -989,7 +982,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
rhs_array_(rhs_array),
addend_array_(addend_array),
executable_run_options_value_(executable_run_options_value),
- ir_builder_(ir_builder),
+ b_(b),
hlo_module_config_(hlo_module_config),
target_machine_features_(target_machine_features) {}
@@ -997,15 +990,14 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
- llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array,
- addend_array, executable_run_options_value,
- ir_builder, hlo_module_config,
- target_machine_features);
+ addend_array, executable_run_options_value, b,
+ hlo_module_config, target_machine_features);
return dot_emitter.Emit();
}
@@ -1050,13 +1042,13 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
}
int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
- ir_builder_->CreateMemSet(
- target, ir_builder_->getInt8(0), size_bytes,
+ b_->CreateMemSet(
+ target, b_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
int64 max_target_vector_width =
target_machine_features_.vector_register_num_elements(
- *ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+ *b_->GetInsertBlock()->getParent(), primitive_type);
int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
@@ -1080,12 +1072,12 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_,
- config.GetCacheKey(), lhs, rhs, target,
+ /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs,
+ rhs, target,
[this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
- MatrixMatrixBlockPanelEmitter gebp_emitter(
- config, /*lhs=*/lhs, /*rhs=*/rhs,
- /*result=*/target, ir_builder_);
+ MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs,
+ /*rhs=*/rhs,
+ /*result=*/target, b_);
gebp_emitter.Emit();
});
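
Note: EmitAndCallOutlinedKernel (used here and in both matrix-vector paths below) appears to emit the lambda's body into a separate function keyed by config.GetCacheKey(), so dots with identical configs in one module share a single outlined kernel. That reading is inferred from the names; the assumed contract, as pseudocode:

    // fn = module.lookup(cache_key);
    // if (fn == nullptr) fn = outline(cache_key, body_generator);  // lambda runs once
    // emit call fn(lhs, rhs, target);
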
@@ -1163,7 +1155,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
const int target_vector_register_element_size =
target_machine_features_.vector_register_num_elements(
- *ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+ *b_->GetInsertBlock()->getParent(), primitive_type);
// We may not always know the vector register size for the target we're
// compiling against, in which case target_vector_register_element_size is 0.
@@ -1184,13 +1176,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_,
- config.GetCacheKey(), lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(),
+ lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
[this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
llvm::Value* addend_op, llvm::Value* result_op) {
ColumnMajorMatrixVectorProductEmitter emitter(
- config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
+ config, lhs_op, rhs_op, addend_op, result_op, b_);
emitter.Emit();
});
} else {
@@ -1203,13 +1195,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_,
- config.GetCacheKey(), lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(),
+ lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
[this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
llvm::Value* addend_op, llvm::Value* result_op) {
- RowMajorMatrixVectorProductEmitter emitter(
- config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
+ RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op,
+ addend_op, result_op, b_);
emitter.Emit();
});
}
@@ -1285,11 +1277,11 @@ Status DotOpEmitter::Emit() {
// Create loop nests which loop through the LHS operand dimensions and the RHS
// operand dimensions. The reduction dimension of the LHS and RHS are handled
// in a separate innermost loop which performs the sum of products.
- llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_);
- llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
- &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs");
- llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
- &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs");
+ llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_);
+ llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest(
+ lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
+ llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
+ rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
// Create the loop which does the sum of products reduction.
//
@@ -1319,62 +1311,55 @@ Status DotOpEmitter::Emit() {
// Function entry basic block.
// - Emit alloca for accumulator
llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
- SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_);
+ SetToFirstInsertPoint(&func->getEntryBlock(), b_);
llvm::Type* accum_type = target_array_.GetElementLlvmType();
- llvm::Value* accum_address = ir_builder_->CreateAlloca(
- accum_type, /*ArraySize=*/nullptr, "accum_address");
+ llvm::Value* accum_address =
+ b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");
// Preheader basic block of reduction loop:
// - Initialize accumulator to zero.
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
- ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
+ b_->SetInsertPoint(preheader_bb->getTerminator());
- ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
- accum_address);
+ b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);
// Body basic block of reduction loop:
// - Load elements from lhs and rhs array.
// - Multiply lhs-element and rhs-element.
// - Load accumulator and add to product.
// - Store sum back into accumulator.
- SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_);
+ SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);
- llvm::Value* lhs_element =
- lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_);
- llvm::Value* rhs_element =
- rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
+ llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
+ llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);
- llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
+ llvm::Value* accum = b_->CreateLoad(accum_address);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
- auto real = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {0});
- };
- auto imag = [&](llvm::Value* x) {
- return ir_builder_->CreateExtractValue(x, {1});
- };
- llvm::Value* product_real = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
- ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
- llvm::Value* product_imag = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
- ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
- updated_accum = ir_builder_->CreateInsertValue(
- accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
- updated_accum = ir_builder_->CreateInsertValue(
- updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
+ auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
+ auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
+ llvm::Value* product_real =
+ b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
+ b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
+ llvm::Value* product_imag =
+ b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
+ b_->CreateFMul(imag(lhs_element), real(rhs_element)));
+ updated_accum = b_->CreateInsertValue(
+ accum, b_->CreateFAdd(real(accum), product_real), {0});
+ updated_accum = b_->CreateInsertValue(
+ updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
} else {
- llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
- updated_accum = ir_builder_->CreateFAdd(accum, product);
+ llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
+ updated_accum = b_->CreateFAdd(accum, product);
}
- ir_builder_->CreateStore(updated_accum, accum_address);
+ b_->CreateStore(updated_accum, accum_address);
// Exit basic block of reduction loop.
// - Load accumulator value (the result).
// - Store into output array.
- SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_);
+ SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);
- llvm::Value* result = ir_builder_->CreateLoad(accum_address);
+ llvm::Value* result = b_->CreateLoad(accum_address);
// Create index into target address. The target index is the concatenation of
// the rhs and lhs indexes with the reduction dimensions removed. The terms
@@ -1392,11 +1377,11 @@ Status DotOpEmitter::Emit() {
}
}
- target_array_.EmitWriteArrayElement(target_index, result, ir_builder_);
+ target_array_.EmitWriteArrayElement(target_index, result, b_);
// Set the IR builder insert point to the exit basic block of the outer most
// loop.
- ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+ b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
return Status::OK();
}
@@ -1405,31 +1390,30 @@ Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
// Use the same index_type for all tensor accesses in the same kernel.
- llvm::Type* index_type = ir_builder_->getInt64Ty();
+ llvm::Type* index_type = b_->getInt64Ty();
llvm_ir::IrArray::Index element_index(index_type);
llvm::Value* lhs_value =
- lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_);
+ lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
llvm::Value* rhs_value =
- rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_);
+ rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
-#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
-#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
- llvm::Value* real = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
- ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
- llvm::Value* imag = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
- ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
+#define REAL(x) b_->CreateExtractValue(x, {0})
+#define IMAG(x) b_->CreateExtractValue(x, {1})
+ llvm::Value* real =
+ b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
+ b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
+ llvm::Value* imag =
+ b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
+ b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
#undef IMAG
#undef REAL
result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
- result = ir_builder_->CreateInsertValue(result, real, {0});
- result = ir_builder_->CreateInsertValue(result, imag, {1});
+ result = b_->CreateInsertValue(result, real, {0});
+ result = b_->CreateInsertValue(result, imag, {1});
} else {
- result = ir_builder_->CreateFMul(lhs_value, rhs_value);
+ result = b_->CreateFMul(lhs_value, rhs_value);
}
- target_array_.EmitWriteArrayElement(/*index=*/element_index, result,
- ir_builder_);
+ target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
return Status::OK();
}
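
Note: EmitScalarDot and the reduction loop above lower the complex product identically, extracting the real part at index {0} and the imaginary part at {1}. The identity being emitted is

    (a + bi) * (c + di) = (ac - bd) + (ad + bc)i

so the real part is FSub(FMul(a,c), FMul(b,d)) and the imaginary part is FAdd(FMul(a,d), FMul(b,c)), matching the FSub/FAdd pairs in both hunks.
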
@@ -1452,7 +1436,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
fn_name = multi_threaded
? runtime::kEigenMatMulF16SymbolName
: runtime::kEigenSingleThreadedMatMulF16SymbolName;
- float_type = ir_builder_->getHalfTy();
+ float_type = b_->getHalfTy();
break;
case F32:
fn_name = multi_threaded
@@ -1461,7 +1445,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
: (use_mkl_dnn
? runtime::kMKLSingleThreadedMatMulF32SymbolName
: runtime::kEigenSingleThreadedMatMulF32SymbolName);
- float_type = ir_builder_->getFloatTy();
+ float_type = b_->getFloatTy();
break;
case F64:
fn_name = multi_threaded
@@ -1470,7 +1454,7 @@ Status DotOpEmitter::EmitCallToRuntime() {
: (use_mkl_dnn
? runtime::kMKLSingleThreadedMatMulF64SymbolName
: runtime::kEigenSingleThreadedMatMulF64SymbolName);
- float_type = ir_builder_->getDoubleTy();
+ float_type = b_->getDoubleTy();
break;
default:
return Unimplemented("Invalid type %s for dot operation",
@@ -1478,16 +1462,16 @@ Status DotOpEmitter::EmitCallToRuntime() {
}
llvm::Type* float_ptr_type = float_type->getPointerTo();
- llvm::Type* int64_type = ir_builder_->getInt64Ty();
- llvm::Type* int32_type = ir_builder_->getInt32Ty();
- llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo();
+ llvm::Type* int64_type = b_->getInt64Ty();
+ llvm::Type* int32_type = b_->getInt32Ty();
+ llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo();
llvm::FunctionType* matmul_type = llvm::FunctionType::get(
- ir_builder_->getVoidTy(),
+ b_->getVoidTy(),
{int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
int64_type, int64_type, int64_type, int32_type, int32_type},
/*isVarArg=*/false);
- llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
+ llvm::Function* function = b_->GetInsertBlock()->getParent();
llvm::Module* module = function->getParent();
llvm::Function* matmul_func = llvm::cast<llvm::Function>(
@@ -1522,18 +1506,15 @@ Status DotOpEmitter::EmitCallToRuntime() {
std::swap(transpose_lhs, transpose_rhs);
}
- ir_builder_->CreateCall(
+ b_->CreateCall(
matmul_func,
- {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
- ir_builder_->CreateBitCast(target_array_.GetBasePointer(),
- float_ptr_type),
- ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
- ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
- ir_builder_->getInt64(mat_mult_dims.m),
- ir_builder_->getInt64(mat_mult_dims.n),
- ir_builder_->getInt64(mat_mult_dims.k),
- ir_builder_->getInt32(transpose_lhs),
- ir_builder_->getInt32(transpose_rhs)});
+ {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
+ b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type),
+ b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
+ b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
+ b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
+ b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs),
+ b_->getInt32(transpose_rhs)});
return Status::OK();
}
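
Note: the FunctionType built in this function matches the XLA CPU runtime matmul entry points. A C-style prototype consistent with that type (F32 shown; the exact symbol spelling comes from the runtime:: constants such as kEigenMatMulF32SymbolName, so treat the name below as an assumption):

    // void __xla_cpu_runtime_EigenMatMulF32(
    //     const void* run_options,          // ExecutableRunOptions*, passed as i8*
    //     float* out, float* lhs, float* rhs,
    //     int64_t m, int64_t n, int64_t k,  // result is m x n, contraction over k
    //     int32_t transpose_lhs, int32_t transpose_rhs);
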
@@ -1556,36 +1537,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0};
}
-llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
- llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
- int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
- // Prepares the dimension list we will use to emit the loop nest. Outermost
- // loops are added first. Add loops in major-to-minor order, and skip the
- // reduction dimension.
- std::vector<int64> dimensions;
- const Shape& shape = operand_array.GetShape();
- for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) {
- int64 dimension = LayoutUtil::Minor(shape.layout(), i);
- if (dimension != reduction_dimension) {
- dimensions.push_back(dimension);
- }
- }
-
- // Create loop nest with one for-loop for each dimension of the
- // output.
- llvm_ir::IrArray::Index index =
- loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
- // Verify every dimension except the reduction dimension was set in the index.
- for (int dimension = 0; dimension < index.size(); ++dimension) {
- if (dimension == reduction_dimension) {
- DCHECK_EQ(nullptr, index[dimension]);
- } else {
- DCHECK_NE(nullptr, index[dimension]);
- }
- }
- return index;
-}
-
// Return whether the given shape is a matrix with no padding.
static bool IsRank2WithNoPadding(const Shape& shape) {
return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
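
Note: the EmitOperandArrayLoopNest removed above is a move, not a deletion: the Emit() hunk earlier now calls loop_nest.EmitOperandArrayLoopNest(...) with a /*dimension_to_skip=*/ argument, so the helper lives on llvm_ir::ForLoopNest. Per the removed body, it adds one loop per operand dimension in major-to-minor order and skips the reduction dimension:

    // for (int i = rank - 1; i >= 0; --i) {
    //   int64 dim = LayoutUtil::Minor(shape.layout(), i);  // major-to-minor walk
    //   if (dim != reduction_dimension) dimensions.push_back(dim);
    // }
    // -> the returned index has a null slot only at the reduction dimension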