aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_function.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_function.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc71
1 files changed, 33 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 2d6f2f3818..6aff838462 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -49,11 +49,10 @@ IrFunction::IrFunction(const string& function_name,
llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
const bool enable_fast_math, llvm::Module* llvm_module,
- llvm::IRBuilder<>* ir_builder,
- int64 num_dynamic_loop_bounds)
- : ir_builder_(ir_builder),
+ llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds)
+ : b_(b),
llvm_module_(llvm_module),
- caller_insert_point_guard_(*ir_builder),
+ caller_insert_point_guard_(*b),
num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
Initialize(function_name, linkage, optimize_for_size_requested,
enable_fast_math);
@@ -61,7 +60,7 @@ IrFunction::IrFunction(const string& function_name,
IrFunction::~IrFunction() {
// Emit function return value.
- ir_builder_->CreateRetVoid();
+ b_->CreateRetVoid();
}
DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
@@ -174,7 +173,7 @@ void IrFunction::Initialize(const string& function_name,
function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
}
- ir_builder_->SetInsertPoint(llvm::BasicBlock::Create(
+ b_->SetInsertPoint(llvm::BasicBlock::Create(
/*Context=*/llvm_module_->getContext(),
/*Name=*/"entry",
/*Parent=*/function_));
@@ -184,9 +183,8 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
CHECK_GT(num_dynamic_loop_bounds_, 0);
CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
- return ir_builder_->CreateLoad(
- ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
- ir_builder_->getInt64(offset), AsStringRef(name)));
+ return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
+ b_->getInt64(offset), AsStringRef(name)));
}
// Emits code to allocate an array of parameter address pointers, and store
@@ -195,27 +193,25 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* b, tensorflow::StringPiece name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- ir_builder->getInt8PtrTy(),
- ir_builder->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"),
- ir_builder);
+ b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
- llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast(
- parameter_addresses[i], ir_builder->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
- "_address_as_i8ptr")));
- llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP(
- parameter_addresses_buffer, {ir_builder->getInt64(i)});
- ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ llvm::Value* parameter_as_i8ptr =
+ b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
+ AsStringRef(tensorflow::strings::StrCat(
+ name, "_parameter_", i, "_address_as_i8ptr")));
+ llvm::Value* slot_in_param_addresses =
+ b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
+ b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}
const auto to_int8_ptr = [=](llvm::Value* ptr) {
- return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy());
+ return b->CreatePointerCast(ptr, b->getInt8PtrTy());
};
std::vector<llvm::Value*> arguments{
to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
@@ -230,22 +226,21 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
// calls to 'parallel_function' (and joins threads before returning).
Status EmitCallToParallelForkJoin(
const std::vector<llvm::Value*>& arguments, const Shape& shape,
- const std::vector<int64>& dimension_partition_counts,
- llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function,
- const string& name) {
- llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+ const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b,
+ llvm::Function* parallel_function, const string& name) {
+ llvm::Module* module = b->GetInsertBlock()->getModule();
// Build ParallelForkJoin function type.
std::vector<llvm::Type*> compute_function_params =
GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
// Number of parallel compute functions.
- compute_function_params.push_back(ir_builder->getInt32Ty());
+ compute_function_params.push_back(b->getInt32Ty());
// Array of partitions. There is an array element for each
// partition x partition_dim x 2 (for dimension start and limit).
compute_function_params.push_back(
llvm::Type::getInt64PtrTy(module->getContext()));
// Number of partitioned most-major dimensions in 'shape'.
- compute_function_params.push_back(ir_builder->getInt32Ty());
+ compute_function_params.push_back(b->getInt32Ty());
// Function pointer for compute function to be dispatched in parallel.
compute_function_params.push_back(
llvm::Type::getInt8PtrTy(module->getContext()));
@@ -268,7 +263,7 @@ Status EmitCallToParallelForkJoin(
ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
// Add argument specifying the number of parallel partitions.
- fork_join_arguments.push_back(ir_builder->getInt32(num_partitions));
+ fork_join_arguments.push_back(b->getInt32(num_partitions));
// The number of partitioned most-major dimensions in 'shape'.
const int32 num_partitioned_dims = dimension_partition_counts.size();
@@ -293,15 +288,15 @@ Status EmitCallToParallelForkJoin(
const std::pair<int64, int64>& dim_partition = dim_partitions[j];
const int32 index = partition_index + j * dim_partition_size;
// Store partition [dim_start, dim_limit) intervals for each dimension.
- partitions[index] = ir_builder->getInt64(dim_partition.first);
+ partitions[index] = b->getInt64(dim_partition.first);
partitions[index + 1] =
- ir_builder->getInt64(dim_partition.first + dim_partition.second);
+ b->getInt64(dim_partition.first + dim_partition.second);
}
}
// Create global variable out of dimension partitions in 'partitions'.
llvm::ArrayType* partitions_array_type =
- llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size);
+ llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
llvm::Constant* partitions_array =
llvm::ConstantArray::get(partitions_array_type, partitions);
llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
@@ -315,16 +310,16 @@ Status EmitCallToParallelForkJoin(
tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
// Add argument specifying parallel dimension partitions.
- fork_join_arguments.push_back(ir_builder->CreateBitCast(
- global_partitions_array,
- llvm::Type::getInt64PtrTy(module->getContext())));
+ fork_join_arguments.push_back(
+ b->CreateBitCast(global_partitions_array,
+ llvm::Type::getInt64PtrTy(module->getContext())));
// Add argument specifying the number of partitioned most-major dimensions.
- fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims));
+ fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
// Add argument for parallel compute function pointer.
fork_join_arguments.push_back(
- ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy()));
+ b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
// Emit call to parallel fork/join.
- ir_builder->CreateCall(fork_join_func, fork_join_arguments);
+ b->CreateCall(fork_join_func, fork_join_arguments);
return Status::OK();
}