diff options
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emission_utils.h | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 177 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_function.cc | 184 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_function.h | 49 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h | 12 |
8 files changed, 257 insertions, 208 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ade887f193..6e7b062280 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -287,10 +287,15 @@ cc_library( srcs = ["ir_function.cc"], hdrs = ["ir_function.h"], deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", ], ) @@ -300,6 +305,7 @@ cc_library( srcs = ["parallel_loop_emitter.cc"], hdrs = ["parallel_loop_emitter.h"], deps = [ + ":ir_emission_utils", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -645,6 +651,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index ac361ddfb4..34b2003916 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -23,6 +24,19 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); + +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 939dbf0e11..70e7aec5c5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1361,7 +1361,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = compute_function_->parameters_arg(); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = @@ -2201,9 +2201,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. - TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2678,7 +2686,7 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -llvm::Argument* IrEmitter::GetProfileCountersArgument() { +llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } @@ -2752,42 +2760,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). -std::vector<llvm::Value*> IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector<llvm::Value*> arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. // // char** parameter_addresses_buffer = @@ -2803,8 +2775,12 @@ void IrEmitter::EmitArrayFunctionCallInto( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2824,111 +2800,6 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector<llvm::Type*> compute_function_params = - compute_function_->GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). - compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. - compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast<llvm::Function>(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector<llvm::Value*> arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. - const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. - std::vector<llvm::Constant*> partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector<std::pair<int64, int64>> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair<int64, int64>& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. - arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. - ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); @@ -2997,14 +2868,8 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { // Emit code to read dynamic loop bounds from compute function argument. - ParallelLoopEmitter::LoopBounds dynamic_loop_bounds( - num_dynamic_loop_bounds_); - for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i].first = - compute_function_->GetDynamicLoopBound(i * 2 + 0); - dynamic_loop_bounds[i].second = - compute_function_->GetDynamicLoopBound(i * 2 + 1); - } + std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); // Emit parallel loop with dynamic loop bounds for most-major dimensions. TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, &dynamic_loop_bounds, &ir_builder_) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 6b576d16bb..692e2b3877 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -237,7 +237,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. - llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -300,18 +300,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector<llvm::Value*> GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -493,7 +481,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { use_rdtscp_(false), prof_counters_(nullptr) {} ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) + llvm::Value* prof_counters) : is_top_level_computation_(is_top_level_computation), use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} @@ -526,7 +514,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index ed257613d8..ca8c290dd1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,6 +30,21 @@ using llvm_ir::AsStringRef; namespace cpu { +static std::vector<llvm::Type*> GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector<llvm::Type*> compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, @@ -47,6 +64,15 @@ IrFunction::~IrFunction() { ir_builder_->CreateRetVoid(); } +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + void IrFunction::Initialize(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, @@ -106,7 +132,8 @@ void IrFunction::Initialize(const string& function_name, // to use GEPs to unravel the indirection layers. llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), - /*Params=*/GetComputeFunctionParams(), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), /*isVarArg=*/false); // Functions with local linkage get an inlining bonus. Because we know @@ -153,21 +180,6 @@ void IrFunction::Initialize(const string& function_name, /*Parent=*/function_)); } -std::vector<llvm::Type*> IrFunction::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = - llvm::Type::getInt8PtrTy(llvm_module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = - llvm::Type::getInt64PtrTy(llvm_module_->getContext()); - std::vector<llvm::Type*> compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - compute_function_params.push_back(i64_ptr_type); - return compute_function_params; -} - llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); @@ -177,5 +189,145 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { ir_builder_->getInt64(offset), AsStringRef(name))); } +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector<llvm::Value*> GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector<llvm::Value*> arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector<llvm::Value*>& arguments, const Shape& shape, + const std::vector<int64>& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. + std::vector<llvm::Type*> compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast<llvm::Function>(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector<llvm::Value*> fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector<llvm::Constant*> partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector<std::pair<int64, int64>> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair<int64, int64>& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index b7516b403e..1fd2da4dce 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -20,8 +20,11 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -54,12 +57,15 @@ class IrFunction { llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); ~IrFunction(); - // Returns an array of compute function parameter types. - std::vector<llvm::Type*> GetComputeFunctionParams(); - - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of this function. - llvm::Value* GetDynamicLoopBound(int64 offset); + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. + // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); // Returns the encapculated llvm::Function. llvm::Function* function() { return function_; } @@ -71,15 +77,15 @@ class IrFunction { // "run_options" argument. llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } - // Get the llvm::Argument that represents this functions parameters argument. - llvm::Argument* parameters_arg() { return parameters_arg_; } + // Get the llvm::Value* that represents this functions parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } // Get the llvm::Value* that represents this functions "temps" argument. llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } // Get the llvm::Value* that represents this functions "prof_counters" // argument. - llvm::Argument* profile_counters_arg() { return profile_counters_arg_; } + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } private: // Initialize an llvm::Function with standard signature based on arguments. @@ -87,6 +93,10 @@ class IrFunction { llvm::Function::LinkageTypes linkage, bool optimize_for_size_requested, bool enable_fast_math); + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + llvm::IRBuilder<>* ir_builder_; llvm::Module* llvm_module_; llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; @@ -97,12 +107,27 @@ class IrFunction { // Function argument IR values. llvm::Argument* result_arg_; llvm::Value* exec_run_options_arg_; - llvm::Argument* parameters_arg_; + llvm::Value* parameters_arg_; llvm::Value* temp_buffers_arg_; - llvm::Argument* dynamic_loop_bounds_arg_ = nullptr; - llvm::Argument* profile_counters_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; }; +// Returns an array of compute function call argument ir values. +std::vector<llvm::Value*> GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector<llvm::Value*>& arguments, const Shape& shape, + const std::vector<int64>& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 91e704e3d0..a3c3c1e5ef 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -24,8 +24,8 @@ namespace cpu { ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, - const llvm_ir::IrArray& target_array, const LoopBounds* dynamic_loop_bounds, - llvm::IRBuilder<>* ir_builder) + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) : LoopEmitter(target_element_generator, target_array, ir_builder), dynamic_loop_bounds_(dynamic_loop_bounds) {} diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 492d5953c4..9335d2818e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" @@ -31,9 +32,8 @@ namespace cpu { // [start, limit) pairs of ir values (one for each partitioned outer dimension). // // EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major -// dimensions dynamic. -// Then 'dynamic_loop_bounds' will contain the following ir values for -// the two most-major dimenions: +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: // [dim0_index_start_ir_value, dim0_index_limit_ir_value] // [dim1_index_start_ir_value, dim1_index_limit_ir_value] // @@ -47,15 +47,13 @@ namespace cpu { // class ParallelLoopEmitter : public llvm_ir::LoopEmitter { public: - using LoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>; - // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the // most-major dimensions, and 'target_array.' shape to set the static loop // bounds for the most-minor dimensions. ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LoopBounds* dynamic_loop_bounds, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; @@ -66,7 +64,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { tensorflow::StringPiece loop_name) override; private: - const LoopBounds* dynamic_loop_bounds_; + const DynamicLoopBounds* dynamic_loop_bounds_; }; } // namespace cpu |