 tensorflow/compiler/xla/service/cpu/BUILD                    |   7
 tensorflow/compiler/xla/service/cpu/ir_emission_utils.h      |  14
 tensorflow/compiler/xla/service/cpu/ir_emitter.cc            | 177
 tensorflow/compiler/xla/service/cpu/ir_emitter.h             |  18
 tensorflow/compiler/xla/service/cpu/ir_function.cc           | 184
 tensorflow/compiler/xla/service/cpu/ir_function.h            |  49
 tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc |   4
 tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h  |  12
 8 files changed, 257 insertions(+), 208 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index ade887f193..6e7b062280 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -287,10 +287,15 @@ cc_library(
srcs = ["ir_function.cc"],
hdrs = ["ir_function.h"],
deps = [
+ ":ir_emission_utils",
+ ":shape_partition",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
"@llvm//:core",
],
)
@@ -300,6 +305,7 @@ cc_library(
srcs = ["parallel_loop_emitter.cc"],
hdrs = ["parallel_loop_emitter.h"],
deps = [
+ ":ir_emission_utils",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
@@ -645,6 +651,7 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/service:hlo",
+ "@llvm//:core",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
index ac361ddfb4..34b2003916 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
+#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@@ -23,6 +24,19 @@ namespace cpu {
bool PotentiallyImplementedAsEigenConvolution(
const HloInstruction& convolution);
+
+// Dynamic loop bounds are specified as an array of dimension index
+// [start, limit) pairs of ir values (one for each partitioned outer dimension).
+//
+// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major
+// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the
+// following ir values for the two most-major dimensions:
+// [dim0_index_start_ir_value, dim0_index_limit_ir_value]
+// [dim1_index_start_ir_value, dim1_index_limit_ir_value]
+//
+// See IrFunction and ParallelLoopEmitter for details.
+using DynamicLoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>;
+
} // namespace cpu
} // namespace xla
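
The new DynamicLoopBounds alias packages one (start, limit) pair of llvm::Value*s per partitioned outer dimension, in most-major order. A minimal sketch of building such a vector for the [8, 16, 32] example above, assuming an existing llvm::IRBuilder<> and using illustrative constant bounds (the real emitter instead loads these values from the compute function's "dynamic_loop_bounds" argument):

#include <utility>
#include <vector>
#include "llvm/IR/IRBuilder.h"

using DynamicLoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>;

// Illustrative bounds only: dim0 covers [0, 4), dim1 covers [8, 16).
DynamicLoopBounds MakeExampleBounds(llvm::IRBuilder<>* b) {
  DynamicLoopBounds bounds(2);
  bounds[0] = {b->getInt64(0), b->getInt64(4)};   // dim0 [start, limit)
  bounds[1] = {b->getInt64(8), b->getInt64(16)};  // dim1 [start, limit)
  return bounds;
}
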
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 939dbf0e11..70e7aec5c5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1361,7 +1361,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
//
// Where Param is the actual element type of the underlying buffer (for
// example, float for an XLA F32 element type).
- llvm::Argument* params = compute_function_->parameters_arg();
+ llvm::Value* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_);
llvm::LoadInst* param_address_untyped =
@@ -2201,9 +2201,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
!parallel_cpu_backend_) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
- TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses,
- emitted_value_[call], computation,
- call_ir_function));
+ std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
+ parameter_addresses, &ir_builder_, computation->name(),
+ /*return_value_buffer=*/emitted_value_[call],
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument());
+
+ HloInstruction* root = computation->root_instruction();
+ TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
+ call_args, root->shape(), root->outer_dimension_partitions(),
+ &ir_builder_, call_ir_function, computation->name()));
} else {
EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
emitted_value_[call], computation->name());
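
This hunk replaces the emitter-private EmitParallelForkJoin with two free helpers that now live in ir_function.cc: GetArrayFunctionCallArguments builds the common compute-function argument vector, and EmitCallToParallelForkJoin appends the partitioning arguments and emits the call into the runtime. For orientation, a hypothetical C-level view of the entry point named by runtime::kParallelForkJoinSymbolName, reconstructed from the llvm::FunctionType assembled in EmitCallToParallelForkJoin later in this diff; the symbol and parameter names below are placeholders, not taken from the runtime header:

#include <cstdint>

// Inferred from the IR parameter types: i8*, i8*, i8**, i8**, i64*, i32, i64*, i32, i8*.
extern "C" void ExampleParallelForkJoin(
    void* result, void* run_options, void** params, void** temps,
    int64_t* prof_counters,
    int32_t num_partitions,        // number of parallel dispatches
    int64_t* partitions,           // flattened [start, limit) pairs per partition
    int32_t num_partitioned_dims,  // partitioned most-major dims of root->shape()
    void* function_ptr);           // compute function run on each partition
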
@@ -2678,7 +2686,7 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
return llvm_ir::ShapeToIrType(shape, module_);
}
-llvm::Argument* IrEmitter::GetProfileCountersArgument() {
+llvm::Value* IrEmitter::GetProfileCountersArgument() {
return compute_function_->profile_counters_arg();
}
@@ -2752,42 +2760,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall(
AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
}
-// Emits code to allocate an array of parameter address pointers, and store
-// each address from 'parameter_addresses'.
-// Returns an array of compute function call arguments (including parameter
-// address buffer).
-std::vector<llvm::Value*> IrEmitter::GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
- llvm::Value* parameter_addresses_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- ir_builder_.getInt8PtrTy(),
- ir_builder_.getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"),
- &ir_builder_);
- for (size_t i = 0; i < parameter_addresses.size(); ++i) {
- llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast(
- parameter_addresses[i], ir_builder_.getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
- "_address_as_i8ptr")));
- llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP(
- parameter_addresses_buffer, {ir_builder_.getInt64(i)});
- ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses);
- }
-
- const auto to_int8_ptr = [this](llvm::Value* ptr) {
- return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy());
- };
- std::vector<llvm::Value*> arguments{
- to_int8_ptr(return_value_buffer),
- to_int8_ptr(GetExecutableRunOptionsArgument()),
- parameter_addresses_buffer, GetTempBuffersArgument()};
- if (auto* profile_counters = GetProfileCountersArgument()) {
- arguments.push_back(profile_counters);
- }
- return arguments;
-}
-
// Emits a core function call based on the following pseudo-code.
//
// char** parameter_addresses_buffer =
@@ -2803,8 +2775,12 @@ void IrEmitter::EmitArrayFunctionCallInto(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
ir_builder_.CreateCall(
- function, GetArrayFunctionCallArguments(parameter_addresses,
- return_value_buffer, name));
+ function, GetArrayFunctionCallArguments(
+ parameter_addresses, &ir_builder_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
}
llvm::Value* IrEmitter::EmitArrayFunctionCall(
@@ -2824,111 +2800,6 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
return return_value_buffer;
}
-// Emits a call to a runtime fork/join function which dispatches parallel
-// calls to 'parallel_function' (and joins threads before returning).
-Status IrEmitter::EmitParallelForkJoin(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* output_address, HloComputation* computation,
- llvm::Function* parallel_function) {
- HloInstruction* root = computation->root_instruction();
-
- // Build ParallelForkJoin function type.
- std::vector<llvm::Type*> compute_function_params =
- compute_function_->GetComputeFunctionParams();
- // Number of parallel compute functions.
- compute_function_params.push_back(ir_builder_.getInt32Ty());
- // Array of partitions. There is an array element for each
- // partition x partition_dim x 2 (for dimension start and limit).
- compute_function_params.push_back(
- llvm::Type::getInt64PtrTy(module_->getContext()));
- // Number of partitioned most-major dimensions in 'root.shape'.
- compute_function_params.push_back(ir_builder_.getInt32Ty());
- // Function pointer for compute function to be dispatched in parallel.
- compute_function_params.push_back(
- llvm::Type::getInt8PtrTy(module_->getContext()));
-
- llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
- /*Result=*/llvm::Type::getVoidTy(module_->getContext()),
- /*Params=*/compute_function_params,
- /*isVarArg=*/false);
-
- llvm::Function* fork_join_func =
- llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kParallelForkJoinSymbolName, fork_join_type));
- fork_join_func->setCallingConv(llvm::CallingConv::C);
- fork_join_func->setDoesNotThrow();
-
- // Add common compute function arguments.
- const string name = computation->name();
- std::vector<llvm::Value*> arguments =
- GetArrayFunctionCallArguments(parameter_addresses, output_address, name);
-
- // Create ShapePartitionIterator to generate all partitions of 'root.shape'.
- ShapePartitionIterator partition_iterator(root->shape(),
- root->outer_dimension_partitions());
- const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
- // Add argument specifying the number of parallel partitions.
- arguments.push_back(ir_builder_.getInt32(num_partitions));
-
- // The number of partitioned most-major dimensions in 'root.shape'.
- const int32 num_partitioned_dims = root->outer_dimension_partitions().size();
- // A dimension partition consists of two elements: [start_index, limit_index).
- const int32 dim_partition_size = 2;
- // Calculate array partition stride.
- const int32 array_partition_stride =
- num_partitioned_dims * dim_partition_size;
- // Calculate the total number of elements in the partition array.
- const int32 partition_array_size =
- dim_partition_size * num_partitioned_dims * num_partitions;
-
- // Store dimension partition values as llvm constants in 'partitions'.
- // See comments in runtime_fork_join.cc for array layout description.
- std::vector<llvm::Constant*> partitions(partition_array_size);
- for (int32 i = 0; i < num_partitions; ++i) {
- std::vector<std::pair<int64, int64>> dim_partitions =
- partition_iterator.GetPartition(i);
- CHECK_EQ(num_partitioned_dims, dim_partitions.size());
- const int32 partition_index = i * array_partition_stride;
- for (int32 j = 0; j < num_partitioned_dims; ++j) {
- const std::pair<int64, int64>& dim_partition = dim_partitions[j];
- const int32 index = partition_index + j * dim_partition_size;
- // Store partition [dim_start, dim_limit) intervals for each dimension.
- partitions[index] = ir_builder_.getInt64(dim_partition.first);
- partitions[index + 1] =
- ir_builder_.getInt64(dim_partition.first + dim_partition.second);
- }
- }
-
- // Create global variable out of dimension partitions in 'partitions'.
- llvm::ArrayType* partitions_array_type =
- llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size);
- llvm::Constant* partitions_array =
- llvm::ConstantArray::get(partitions_array_type, partitions);
- llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
- /*Module=*/*module_,
- /*Type=*/partitions_array_type,
- /*isConstant=*/true,
- /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/partitions_array,
- /*Name=*/
- AsStringRef(
- tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
-
- // Add argument specifying parallel dimension partitions.
- arguments.push_back(ir_builder_.CreateBitCast(
- global_partitions_array,
- llvm::Type::getInt64PtrTy(module_->getContext())));
- // Add argument specifying the number of partitioned most-major dimensions.
- arguments.push_back(ir_builder_.getInt32(num_partitioned_dims));
- // Add argument for parallel compute function pointer.
- arguments.push_back(
- ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy()));
- // Emit call to parallel fork/join.
- ir_builder_.CreateCall(fork_join_func, arguments);
-
- return Status::OK();
-}
-
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
llvm::Value* addr;
const Shape& target_shape = op->shape();
@@ -2997,14 +2868,8 @@ Status IrEmitter::EmitTargetElementLoop(
} else {
if (ShouldEmitParallelLoopFor(*target_op)) {
// Emit code to read dynamic loop bounds from compute function argument.
- ParallelLoopEmitter::LoopBounds dynamic_loop_bounds(
- num_dynamic_loop_bounds_);
- for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
- dynamic_loop_bounds[i].first =
- compute_function_->GetDynamicLoopBound(i * 2 + 0);
- dynamic_loop_bounds[i].second =
- compute_function_->GetDynamicLoopBound(i * 2 + 1);
- }
+ std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
+ compute_function_->GetDynamicLoopBounds();
// Emit parallel loop with dynamic loop bounds for most-major dimensions.
TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
&dynamic_loop_bounds, &ir_builder_)
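
After the refactor, IrEmitter simply asks IrFunction for the already-unpacked bound pairs and hands them to ParallelLoopEmitter. Conceptually, the loop nest emitted for the running [8, 16, 32] example behaves like the scalar model below; this is a sketch only, with made-up names, since the real loops are emitted as LLVM IR by ParallelLoopEmitter:

#include <cstdint>

// One parallel partition: dynamic [start, limit) bounds on the two most-major
// dimensions, static bound on the most-minor dimension.
void RunPartition(int64_t start0, int64_t limit0,
                  int64_t start1, int64_t limit1,
                  float (&target)[8][16][32]) {
  for (int64_t i0 = start0; i0 < limit0; ++i0) {
    for (int64_t i1 = start1; i1 < limit1; ++i1) {
      for (int64_t i2 = 0; i2 < 32; ++i2) {
        target[i0][i1][i2] = 0.0f;  // stands in for the element generator
      }
    }
  }
}
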
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 6b576d16bb..692e2b3877 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -237,7 +237,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Get the llvm::Value* that represents the "prof_counters" argument of the
// computation function being emitted by this emitter.
- llvm::Argument* GetProfileCountersArgument();
+ llvm::Value* GetProfileCountersArgument();
// Get the xla::ExecutableRunOptions that represents the "run_options"
// argument of the computation function being emitted by this emitter.
@@ -300,18 +300,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name);
- // Returns an array of compute function call arguments.
- std::vector<llvm::Value*> GetArrayFunctionCallArguments(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name);
-
- // Emits a call to a runtime fork/join function which dispatches parallel
- // calls to 'parallel_function' (and joins threads before returning).
- Status EmitParallelForkJoin(
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* output_address, HloComputation* computation,
- llvm::Function* parallel_function);
-
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
Status ElementTypesSameAndSupported(
@@ -493,7 +481,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
use_rdtscp_(false),
prof_counters_(nullptr) {}
ProfilingState(bool is_top_level_computation, bool use_rdtscp,
- llvm::Argument* prof_counters)
+ llvm::Value* prof_counters)
: is_top_level_computation_(is_top_level_computation),
use_rdtscp_(use_rdtscp),
prof_counters_(prof_counters) {}
@@ -526,7 +514,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool use_rdtscp_;
// The argument which corresponds to the profile counter buffer.
- llvm::Argument* prof_counters_;
+ llvm::Value* prof_counters_;
// The first read cycle counter in the program.
llvm::Value* first_read_cycle_start_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index ed257613d8..ca8c290dd1 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -28,6 +30,21 @@ using llvm_ir::AsStringRef;
namespace cpu {
+static std::vector<llvm::Type*> GetComputeFunctionParams(
+ llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
+ llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
+ llvm::Type* i64_ptr_type =
+ llvm::Type::getInt64PtrTy(llvm_module->getContext());
+ std::vector<llvm::Type*> compute_function_params(
+ {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
+ if (num_dynamic_loop_bounds > 0) {
+ compute_function_params.push_back(i64_ptr_type);
+ }
+ compute_function_params.push_back(i64_ptr_type);
+ return compute_function_params;
+}
+
IrFunction::IrFunction(const string& function_name,
llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
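
Hoisting GetComputeFunctionParams out of IrFunction into a file-local helper lets EmitCallToParallelForkJoin below reuse the same parameter layout when it builds the fork/join function type. The sketch below restates that layout with the argument names implied by IrFunction's accessors (result, run_options, params, temps, optional dynamic loop bounds, profile counters); it is illustrative, not a second copy meant for use:

#include <vector>
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"

std::vector<llvm::Type*> ExampleComputeParams(llvm::Module* m,
                                              bool has_dynamic_loop_bounds) {
  llvm::Type* i8_ptr = llvm::Type::getInt8PtrTy(m->getContext());
  llvm::Type* i8_ptr_ptr = i8_ptr->getPointerTo();
  llvm::Type* i64_ptr = llvm::Type::getInt64PtrTy(m->getContext());
  std::vector<llvm::Type*> params = {i8_ptr,       // result buffer
                                     i8_ptr,       // xla::ExecutableRunOptions
                                     i8_ptr_ptr,   // parameter addresses
                                     i8_ptr_ptr};  // temp buffers
  if (has_dynamic_loop_bounds) params.push_back(i64_ptr);  // dynamic loop bounds
  params.push_back(i64_ptr);                               // profile counters
  return params;
}
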
@@ -47,6 +64,15 @@ IrFunction::~IrFunction() {
ir_builder_->CreateRetVoid();
}
+DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
+ DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
+ for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
+ dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
+ dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
+ }
+ return dynamic_loop_bounds;
+}
+
void IrFunction::Initialize(const string& function_name,
llvm::Function::LinkageTypes linkage,
const bool optimize_for_size_requested,
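
The i*2 indexing above implies that the compute function's "dynamic_loop_bounds" argument is a flat i64 array of interleaved [start, limit) values, one pair per partitioned dimension. A hypothetical host-side model of the same unpacking (the IR version loads each value through GetDynamicLoopBound with a GEP):

#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> UnpackLoopBounds(
    const int64_t* dynamic_loop_bounds, int num_dynamic_loop_bounds) {
  std::vector<std::pair<int64_t, int64_t>> bounds(num_dynamic_loop_bounds);
  for (int i = 0; i < num_dynamic_loop_bounds; ++i) {
    bounds[i].first = dynamic_loop_bounds[i * 2 + 0];   // dimension start
    bounds[i].second = dynamic_loop_bounds[i * 2 + 1];  // dimension limit
  }
  return bounds;
}
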
@@ -106,7 +132,8 @@ void IrFunction::Initialize(const string& function_name,
// to use GEPs to unravel the indirection layers.
llvm::FunctionType* function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
- /*Params=*/GetComputeFunctionParams(),
+ /*Params=*/
+ GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
/*isVarArg=*/false);
// Functions with local linkage get an inlining bonus. Because we know
@@ -153,21 +180,6 @@ void IrFunction::Initialize(const string& function_name,
/*Parent=*/function_));
}
-std::vector<llvm::Type*> IrFunction::GetComputeFunctionParams() {
- llvm::Type* i8_ptr_type =
- llvm::Type::getInt8PtrTy(llvm_module_->getContext());
- llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
- llvm::Type* i64_ptr_type =
- llvm::Type::getInt64PtrTy(llvm_module_->getContext());
- std::vector<llvm::Type*> compute_function_params(
- {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
- if (num_dynamic_loop_bounds_ > 0) {
- compute_function_params.push_back(i64_ptr_type);
- }
- compute_function_params.push_back(i64_ptr_type);
- return compute_function_params;
-}
-
llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
CHECK_GT(num_dynamic_loop_bounds_, 0);
CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
@@ -177,5 +189,145 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
ir_builder_->getInt64(offset), AsStringRef(name)));
}
+// Emits code to allocate an array of parameter address pointers, and store
+// each address from 'parameter_addresses'.
+// Returns an array of compute function call arguments (including parameter
+// address buffer).
+std::vector<llvm::Value*> GetArrayFunctionCallArguments(
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name,
+ llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
+ llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
+ llvm::Value* parameter_addresses_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ ir_builder->getInt8PtrTy(),
+ ir_builder->getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"),
+ ir_builder);
+ for (size_t i = 0; i < parameter_addresses.size(); ++i) {
+ llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast(
+ parameter_addresses[i], ir_builder->getInt8PtrTy(),
+ AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
+ "_address_as_i8ptr")));
+ llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP(
+ parameter_addresses_buffer, {ir_builder->getInt64(i)});
+ ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses);
+ }
+
+ const auto to_int8_ptr = [=](llvm::Value* ptr) {
+ return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy());
+ };
+ std::vector<llvm::Value*> arguments{
+ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
+ parameter_addresses_buffer, temp_buffers_arg};
+ if (profile_counters_arg != nullptr) {
+ arguments.push_back(profile_counters_arg);
+ }
+ return arguments;
+}
+
+// Emits a call to a runtime fork/join function which dispatches parallel
+// calls to 'parallel_function' (and joins threads before returning).
+Status EmitCallToParallelForkJoin(
+ const std::vector<llvm::Value*>& arguments, const Shape& shape,
+ const std::vector<int64>& dimension_partition_counts,
+ llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function,
+ const string& name) {
+ llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+
+ // Build ParallelForkJoin function type.
+ std::vector<llvm::Type*> compute_function_params =
+ GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
+ // Number of parallel compute functions.
+ compute_function_params.push_back(ir_builder->getInt32Ty());
+ // Array of partitions. There is an array element for each
+ // partition x partition_dim x 2 (for dimension start and limit).
+ compute_function_params.push_back(
+ llvm::Type::getInt64PtrTy(module->getContext()));
+ // Number of partitioned most-major dimensions in 'shape'.
+ compute_function_params.push_back(ir_builder->getInt32Ty());
+ // Function pointer for compute function to be dispatched in parallel.
+ compute_function_params.push_back(
+ llvm::Type::getInt8PtrTy(module->getContext()));
+
+ llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
+ /*Result=*/llvm::Type::getVoidTy(module->getContext()),
+ /*Params=*/compute_function_params,
+ /*isVarArg=*/false);
+
+ llvm::Function* fork_join_func =
+ llvm::cast<llvm::Function>(module->getOrInsertFunction(
+ runtime::kParallelForkJoinSymbolName, fork_join_type));
+ fork_join_func->setCallingConv(llvm::CallingConv::C);
+ fork_join_func->setDoesNotThrow();
+
+ // Add common compute function arguments.
+ std::vector<llvm::Value*> fork_join_arguments(arguments);
+
+ // Create ShapePartitionIterator to generate all partitions of 'shape'.
+ ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
+ const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
+ // Add argument specifying the number of parallel partitions.
+ fork_join_arguments.push_back(ir_builder->getInt32(num_partitions));
+
+ // The number of partitioned most-major dimensions in 'shape'.
+ const int32 num_partitioned_dims = dimension_partition_counts.size();
+ // A dimension partition consists of two elements: [start_index, limit_index).
+ const int32 dim_partition_size = 2;
+ // Calculate array partition stride.
+ const int32 array_partition_stride =
+ num_partitioned_dims * dim_partition_size;
+ // Calculate the total number of elements in the partition array.
+ const int32 partition_array_size =
+ dim_partition_size * num_partitioned_dims * num_partitions;
+
+ // Store dimension partition values as llvm constants in 'partitions'.
+ // See comments in runtime_fork_join.cc for array layout description.
+ std::vector<llvm::Constant*> partitions(partition_array_size);
+ for (int32 i = 0; i < num_partitions; ++i) {
+ std::vector<std::pair<int64, int64>> dim_partitions =
+ partition_iterator.GetPartition(i);
+ CHECK_EQ(num_partitioned_dims, dim_partitions.size());
+ const int32 partition_index = i * array_partition_stride;
+ for (int32 j = 0; j < num_partitioned_dims; ++j) {
+ const std::pair<int64, int64>& dim_partition = dim_partitions[j];
+ const int32 index = partition_index + j * dim_partition_size;
+ // Store partition [dim_start, dim_limit) intervals for each dimension.
+ partitions[index] = ir_builder->getInt64(dim_partition.first);
+ partitions[index + 1] =
+ ir_builder->getInt64(dim_partition.first + dim_partition.second);
+ }
+ }
+
+ // Create global variable out of dimension partitions in 'partitions'.
+ llvm::ArrayType* partitions_array_type =
+ llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size);
+ llvm::Constant* partitions_array =
+ llvm::ConstantArray::get(partitions_array_type, partitions);
+ llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
+ /*M=*/*module,
+ /*Ty=*/partitions_array_type,
+ /*isConstant=*/true,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/partitions_array,
+ /*Name=*/
+ AsStringRef(
+ tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
+
+ // Add argument specifying parallel dimension partitions.
+ fork_join_arguments.push_back(ir_builder->CreateBitCast(
+ global_partitions_array,
+ llvm::Type::getInt64PtrTy(module->getContext())));
+ // Add argument specifying the number of partitioned most-major dimensions.
+ fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims));
+ // Add argument for parallel compute function pointer.
+ fork_join_arguments.push_back(
+ ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy()));
+ // Emit call to parallel fork/join.
+ ir_builder->CreateCall(fork_join_func, fork_join_arguments);
+
+ return Status::OK();
+}
+
} // namespace cpu
} // namespace xla
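
The flattened partitions array built by EmitCallToParallelForkJoin groups entries per partition, then per partitioned dimension, storing each dimension's start and limit (start + size) side by side; runtime_fork_join.cc documents the same layout. The index helpers below restate that layout and are illustrative only, not part of the change. For example, a root shape [8, 16, 32] split 2x2 over its two most-major dimensions yields 4 partitions x 2 dims x 2 entries = 16 int64s, and PartitionLimitIndex(3, 1, 2) == 15.

#include <cstdint>

inline int64_t PartitionStartIndex(int64_t partition, int64_t dim,
                                   int64_t num_partitioned_dims) {
  return partition * num_partitioned_dims * 2 + dim * 2;
}
inline int64_t PartitionLimitIndex(int64_t partition, int64_t dim,
                                   int64_t num_partitioned_dims) {
  return PartitionStartIndex(partition, dim, num_partitioned_dims) + 1;
}
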
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index b7516b403e..1fd2da4dce 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -20,8 +20,11 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -54,12 +57,15 @@ class IrFunction {
llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds);
~IrFunction();
- // Returns an array of compute function parameter types.
- std::vector<llvm::Type*> GetComputeFunctionParams();
-
- // Emit ir to read and return the ir value for the dynamic loop bound at
- // 'offset' from the "dynamic_loop_bounds" argument of this function.
- llvm::Value* GetDynamicLoopBound(int64 offset);
+ // Emit ir to read and return the set of ir values representing the dynamic
+ // loop bounds argument of this function.
+ // Each element in returned vector is a pair of ir values representing
+ // the loop bounds for a specific dimension, where the first element of the
+ // pair is the dimension start index, and the second element of the pair
+ // is the dimension limit.
+ // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value]
+ //
+ DynamicLoopBounds GetDynamicLoopBounds();
// Returns the encapculated llvm::Function.
llvm::Function* function() { return function_; }
@@ -71,15 +77,15 @@ class IrFunction {
// "run_options" argument.
llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; }
- // Get the llvm::Argument that represents this functions parameters argument.
- llvm::Argument* parameters_arg() { return parameters_arg_; }
+ // Get the llvm::Value* that represents this functions parameters argument.
+ llvm::Value* parameters_arg() { return parameters_arg_; }
// Get the llvm::Value* that represents this functions "temps" argument.
llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; }
// Get the llvm::Value* that represents this functions "prof_counters"
// argument.
- llvm::Argument* profile_counters_arg() { return profile_counters_arg_; }
+ llvm::Value* profile_counters_arg() { return profile_counters_arg_; }
private:
// Initialize an llvm::Function with standard signature based on arguments.
@@ -87,6 +93,10 @@ class IrFunction {
llvm::Function::LinkageTypes linkage,
bool optimize_for_size_requested, bool enable_fast_math);
+ // Emit ir to read and return the ir value for the dynamic loop bound at
+ // 'offset' from the "dynamic_loop_bounds" argument of this function.
+ llvm::Value* GetDynamicLoopBound(int64 offset);
+
llvm::IRBuilder<>* ir_builder_;
llvm::Module* llvm_module_;
llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_;
@@ -97,12 +107,27 @@ class IrFunction {
// Function argument IR values.
llvm::Argument* result_arg_;
llvm::Value* exec_run_options_arg_;
- llvm::Argument* parameters_arg_;
+ llvm::Value* parameters_arg_;
llvm::Value* temp_buffers_arg_;
- llvm::Argument* dynamic_loop_bounds_arg_ = nullptr;
- llvm::Argument* profile_counters_arg_;
+ llvm::Value* dynamic_loop_bounds_arg_ = nullptr;
+ llvm::Value* profile_counters_arg_;
};
+// Returns an array of compute function call argument ir values.
+std::vector<llvm::Value*> GetArrayFunctionCallArguments(
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name,
+ llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
+ llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
+
+// Emits a call to a runtime fork/join function which dispatches parallel
+// calls to 'parallel_function' (and joins threads before returning).
+Status EmitCallToParallelForkJoin(
+ const std::vector<llvm::Value*>& arguments, const Shape& shape,
+ const std::vector<int64>& dimension_partition_counts,
+ llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function,
+ const string& name);
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 91e704e3d0..a3c3c1e5ef 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -24,8 +24,8 @@ namespace cpu {
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
- const llvm_ir::IrArray& target_array, const LoopBounds* dynamic_loop_bounds,
- llvm::IRBuilder<>* ir_builder)
+ const llvm_ir::IrArray& target_array,
+ const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder)
: LoopEmitter(target_element_generator, target_array, ir_builder),
dynamic_loop_bounds_(dynamic_loop_bounds) {}
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 492d5953c4..9335d2818e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
@@ -31,9 +32,8 @@ namespace cpu {
// [start, limit) pairs of ir values (one for each partitioned outer dimension).
//
// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major
-// dimensions dynamic.
-// Then 'dynamic_loop_bounds' will contain the following ir values for
-// the two most-major dimenions:
+// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the
+// following ir values for the two most-major dimensions:
// [dim0_index_start_ir_value, dim0_index_limit_ir_value]
// [dim1_index_start_ir_value, dim1_index_limit_ir_value]
//
@@ -47,15 +47,13 @@ namespace cpu {
//
class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
public:
- using LoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>;
-
// Constructs a ParallelLoopEmitter which uses 'target_element_generator' to
// generate elements, 'dynamic_loop_bounds' to set the loop bounds of the
// most-major dimensions, and 'target_array.' shape to set the static loop
// bounds for the most-minor dimensions.
ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
- const LoopBounds* dynamic_loop_bounds,
+ const DynamicLoopBounds* dynamic_loop_bounds,
llvm::IRBuilder<>* ir_builder);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
@@ -66,7 +64,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
tensorflow::StringPiece loop_name) override;
private:
- const LoopBounds* dynamic_loop_bounds_;
+ const DynamicLoopBounds* dynamic_loop_bounds_;
};
} // namespace cpu