diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_function.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_function.cc | 71 |
1 files changed, 33 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2d6f2f3818..6aff838462 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -49,11 +49,10 @@ IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* ir_builder, - int64 num_dynamic_loop_bounds) - : ir_builder_(ir_builder), + llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds) + : b_(b), llvm_module_(llvm_module), - caller_insert_point_guard_(*ir_builder), + caller_insert_point_guard_(*b), num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { Initialize(function_name, linkage, optimize_for_size_requested, enable_fast_math); @@ -61,7 +60,7 @@ IrFunction::IrFunction(const string& function_name, IrFunction::~IrFunction() { // Emit function return value. - ir_builder_->CreateRetVoid(); + b_->CreateRetVoid(); } DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { @@ -174,7 +173,7 @@ void IrFunction::Initialize(const string& function_name, function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); } - ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + b_->SetInsertPoint(llvm::BasicBlock::Create( /*Context=*/llvm_module_->getContext(), /*Name=*/"entry", /*Parent=*/function_)); @@ -184,9 +183,8 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_->CreateLoad( - ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), - ir_builder_->getInt64(offset), AsStringRef(name))); + return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + b_->getInt64(offset), AsStringRef(name))); } // Emits code to allocate an array of parameter address pointers, and store @@ -195,27 +193,25 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector<llvm::Value*> GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder->getInt8PtrTy(), - ir_builder->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - ir_builder); + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( - parameter_addresses[i], ir_builder->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder->getInt64(i)}); - ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); } const auto to_int8_ptr = [=](llvm::Value* ptr) { - return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + return b->CreatePointerCast(ptr, b->getInt8PtrTy()); }; std::vector<llvm::Value*> arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), @@ -230,22 +226,21 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( // calls to 'parallel_function' (and joins threads before returning). Status EmitCallToParallelForkJoin( const std::vector<llvm::Value*>& arguments, const Shape& shape, - const std::vector<int64>& dimension_partition_counts, - llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, - const string& name) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b, + llvm::Function* parallel_function, const string& name) { + llvm::Module* module = b->GetInsertBlock()->getModule(); // Build ParallelForkJoin function type. std::vector<llvm::Type*> compute_function_params = GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); // Number of parallel compute functions. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Array of partitions. There is an array element for each // partition x partition_dim x 2 (for dimension start and limit). compute_function_params.push_back( llvm::Type::getInt64PtrTy(module->getContext())); // Number of partitioned most-major dimensions in 'shape'. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Function pointer for compute function to be dispatched in parallel. compute_function_params.push_back( llvm::Type::getInt8PtrTy(module->getContext())); @@ -268,7 +263,7 @@ Status EmitCallToParallelForkJoin( ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); // Add argument specifying the number of parallel partitions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + fork_join_arguments.push_back(b->getInt32(num_partitions)); // The number of partitioned most-major dimensions in 'shape'. const int32 num_partitioned_dims = dimension_partition_counts.size(); @@ -293,15 +288,15 @@ Status EmitCallToParallelForkJoin( const std::pair<int64, int64>& dim_partition = dim_partitions[j]; const int32 index = partition_index + j * dim_partition_size; // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index] = b->getInt64(dim_partition.first); partitions[index + 1] = - ir_builder->getInt64(dim_partition.first + dim_partition.second); + b->getInt64(dim_partition.first + dim_partition.second); } } // Create global variable out of dimension partitions in 'partitions'. llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::ArrayType::get(b->getInt64Ty(), partition_array_size); llvm::Constant* partitions_array = llvm::ConstantArray::get(partitions_array_type, partitions); llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( @@ -315,16 +310,16 @@ Status EmitCallToParallelForkJoin( tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. - fork_join_arguments.push_back(ir_builder->CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module->getContext()))); + fork_join_arguments.push_back( + b->CreateBitCast(global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); // Add argument specifying the number of partitioned most-major dimensions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + fork_join_arguments.push_back(b->getInt32(num_partitioned_dims)); // Add argument for parallel compute function pointer. fork_join_arguments.push_back( - ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + b->CreateBitCast(parallel_function, b->getInt8PtrTy())); // Emit call to parallel fork/join. - ir_builder->CreateCall(fork_join_func, fork_join_arguments); + b->CreateCall(fork_join_func, fork_join_arguments); return Status::OK(); } |