author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-11 14:45:28 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-11 14:49:20 -0700
commit 10d0ae696c7b5618cae9e3845af8300fe62870a2 (patch)
tree   03b971813653820bce5f366bdebd0aa01f687a73
parent f640c8980571d7578e891ea5ceab55978c8db9b4 (diff)
[XLA:CPU] Adds intra-op parallelism to the "sequential" CPU backend (which already has intra-op parallelism for library calls).
Adds support for parallel task assignment to instructions in entry (or embedded) computations.
Adds code to emit calls to a new runtime parallel fork/join function for instructions which have been assigned parallel tasks.
Adds a simple cost model for I/O bound instructions.

*) Translation (deleuze model) wall time (seconds).
                 large_model   small_model   small_model_small_attn
    sequential:  0.00556       0.00484       0.00155
    parallel:    0.00263       0.00163       0.00106

*) Wavenet
    sequential: Avg. latency (30 runs): 1026.13ms, min/max: 988/1108ms
    parallel:   Avg. latency (30 runs): 800.633ms, min/max: 785/818ms

*) ParallelFusion benchmark.
    Benchmark                            Time(ns)   CPU(ns)   Iterations
    ----------------------------------------------------------
    sequential cpu backend (at head)       610584    611467         1000
    parallel cpu backend                    153241    836097         4528
    sequential cpu backend (this CL)        113482    679535         6017

PiperOrigin-RevId: 171877766
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD                               |  18
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc                     |  15
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.h                      |   2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc  |  20
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.cc                      |   3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.h                       |   1
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc                       | 192
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h                        |  21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc         | 148
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h          |  49
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc                |  93
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fork_join.h                 |  33
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc                   |   2
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                                     |   2
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc                            | 136
15 files changed, 644 insertions(+), 91 deletions(-)
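To make the new cost model concrete, here is a minimal standalone C++ sketch of the per-instruction task-count heuristic that DefaultCostModel::GetParallelTaskCount implements in the parallel_task_assignment.cc hunk below. The constants (flops/byte threshold, 256KB L2 size, 100us per-thread floor) mirror the patch; the free-function framing and parameter names are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch: pick a parallel task count for one HLO instruction.
int64_t TargetParallelTaskCount(int64_t flop_count, int64_t transcendental_count,
                                int64_t bytes_accessed, int64_t shape_size_bytes,
                                int64_t max_parallelism, int num_schedulable_cpus) {
  bytes_accessed = std::max<int64_t>(1, bytes_accessed);
  const float flops_to_bytes_ratio =
      flop_count / static_cast<float>(bytes_accessed);
  int64_t instruction_cost;
  int64_t min_cost_per_thread;
  int64_t parallelism_cap;
  if (flops_to_bytes_ratio <= 1.0f) {
    // I/O bound: cap parallelism at sqrt(#CPUs); cost by output size vs. L2.
    parallelism_cap =
        static_cast<int64_t>(std::ceil(std::sqrt(num_schedulable_cpus)));
    instruction_cost = shape_size_bytes;
    min_cost_per_thread = 256LL << 10;  // 256KB L2 cache size.
  } else {
    // Compute bound: linear cycle estimate vs. ~100us of work on a 2GHz core.
    parallelism_cap = max_parallelism;
    instruction_cost =
        1 * flop_count + 2 * transcendental_count + 10 * bytes_accessed;
    min_cost_per_thread = 100000;
  }
  return std::min(parallelism_cap,
                  std::max<int64_t>(1, instruction_cost / min_cost_per_thread));
}

For example, an elementwise op writing a 4MB F32 output has a flops-to-bytes ratio well below 1, so it is treated as I/O bound: 4MB / 256KB yields 16 candidate tasks, which on a 24-core machine is then capped at ceil(sqrt(24)) = 5.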
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8ab358fe17..c71eca0d39 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -87,6 +87,7 @@ cc_library(
":ir_emitter",
":layout_assignment",
":parallel_cpu_executable",
+ ":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
@@ -155,6 +156,7 @@ cc_library(
":disassembler",
":external_constant_pool",
":runtime_conv2d",
+ ":runtime_fork_join",
":runtime_matmul",
":runtime_single_threaded_conv2d",
":runtime_single_threaded_matmul",
@@ -243,6 +245,7 @@ cc_library(
":dot_op_emitter",
":external_constant_pool",
":ir_emission_utils",
+ ":shape_partition",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -505,6 +508,20 @@ cc_library(
],
)
+cc_library(
+ name = "runtime_fork_join",
+ srcs = ["runtime_fork_join.cc"],
+ hdrs = ["runtime_fork_join.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_test(
name = "cpu_runtime_test",
srcs = ["cpu_runtime_test.cc"],
@@ -688,6 +705,7 @@ cc_library(
":shape_partition",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
+ "//tensorflow/compiler/xla/service:hlo_pass",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 386800d221..3272044faa 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -58,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h"
+#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -248,7 +249,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
};
} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module) {
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// Optimization pipeline.
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction());
@@ -316,6 +317,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
if (options::CpuParallelBackendRequested(module->config())) {
pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
ShapeSizeBytesFunction());
+ } else if (!is_aot_compile) {
+ // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
+ // Note this is not run for AOT because it would bring in thread pool
+ // and thread synchronization dependencies which would likely increase
+ // binary size (and most AOT applications are single-threaded).
+ // TODO(29630486) Support multi-threaded AOT.
+ pipeline.AddPass<ParallelTaskAssigner>(max_parallelism,
+ ShapeSizeBytesFunction(), module);
}
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -450,7 +459,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
llvm_module->setDataLayout(jit->data_layout());
llvm_module->setTargetTriple(jit->target_triple().getTriple());
- TF_RETURN_IF_ERROR(RunHloPasses(module.get()));
+ TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));
HloComputation* computation = module->entry_computation();
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
@@ -749,7 +758,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
HloModule* module = modules[i].get();
VLOG(1) << "Compiling ahead-of-time: " << module->name();
- TF_RETURN_IF_ERROR(RunHloPasses(module));
+ TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true));
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index bd3541500d..21dd128619 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -131,7 +131,7 @@ class CpuCompiler : public LLVMCompiler {
// Runs the HLO passes which are necessary for both optimizations and
// correctness.
- Status RunHloPasses(HloModule* module);
+ Status RunHloPasses(HloModule* module, bool is_aot_compile);
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
index 2cd0aa7880..662ee60923 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
@@ -116,26 +116,6 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
// Assign parallel tasks to HLOs in entry computation.
HloComputation* computation = module->entry_computation();
for (auto* instruction : computation->instructions()) {
- // Currently, we do not assign parallel tasks to instructions with at least
- // one of the following properties:
- // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
- // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
- // *) Tuple-shaped.
- // TODO(b/27458679) Parallelize instructions which are skipped here.
- if (instruction->opcode() == HloOpcode::kParameter ||
- instruction->opcode() == HloOpcode::kConstant ||
- instruction->opcode() == HloOpcode::kCall ||
- instruction->opcode() == HloOpcode::kCustomCall ||
- instruction->opcode() == HloOpcode::kSelectAndScatter ||
- (instruction->opcode() == HloOpcode::kConvolution &&
- PotentiallyImplementedAsEigenConvolution(*instruction)) ||
- PotentiallyImplementedAsEigenDot(*instruction) ||
- (instruction->opcode() == HloOpcode::kFusion &&
- instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
- ShapeUtil::IsTuple(instruction->shape())) {
- continue;
- }
-
// Calculate target parallel task count in [1, max_parallelism_].
const int64 target_parallel_task_count =
parallel_task_assignment.GetTargetParallelTaskCount(instruction);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index c7155b858b..7908dc173d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -51,6 +51,9 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
"__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
+extern const char* const kParallelForkJoinSymbolName =
+ "__xla_cpu_runtime_ParallelForkJoin";
+
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 29feb7267f..2ade455b8a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -51,6 +51,7 @@ extern const char* const kAcquireInfeedBufferForDequeueSymbolName;
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
+extern const char* const kParallelForkJoinSymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 633ad0290c..c38325554f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -186,20 +187,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
// Even though the type of params and temps is void** in the host's view, in
// LLVM IR this is represented by i8*, similarly to void*. It's up to the code
// to use GEPs to unravel the indirection layers.
- llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
- llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
- llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext());
- std::vector<llvm::Type*> compute_function_params(
- {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
- if (IsParallelContext()) {
- compute_function_params.push_back(i64_ptr_type);
- }
- if (hlo_to_profile_idx_) {
- compute_function_params.push_back(i64_ptr_type);
- }
llvm::FunctionType* compute_function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(module_->getContext()),
- /*Params=*/compute_function_params,
+ /*Params=*/GetComputeFunctionParams(),
/*isVarArg=*/false);
// Functions with local linkage get an inlining bonus. Because we know
@@ -221,7 +211,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
(++arg_iter)->setName("run_options");
(++arg_iter)->setName("params");
(++arg_iter)->setName("temps");
- if (IsParallelContext()) {
+ if (num_dynamic_loop_bounds_ > 0) {
(++arg_iter)->setName("dynamic_loop_bounds");
}
if (hlo_to_profile_idx_) {
@@ -2286,8 +2276,19 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
- EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
- emitted_value_[call], computation->name());
+
+ if (!computation->root_instruction()->outer_dimension_partitions().empty() &&
+ !parallel_cpu_backend_) {
+ // ParallelTaskAssignment assigned partitions, emit call to
+ // ParallelForkJoin.
+ TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses,
+ emitted_value_[call], computation,
+ call_ir_function));
+ } else {
+ EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
+ emitted_value_[call], computation->name());
+ }
+
return Status::OK();
}
@@ -2597,7 +2598,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) {
// For the parallel cpu backend, we record the total for each embedded
// computation callee with its caller kCall HLO.
HloInstruction* hlo_to_lookup = nullptr;
- if (IsParallelContext()) {
+ if (parallel_cpu_backend_ && is_top_level_computation_) {
auto* computation = root->parent();
auto* entry_computation = computation->parent()->entry_computation();
if (computation != entry_computation) {
@@ -2755,12 +2756,27 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
return llvm_ir::ShapeToIrType(shape, &ir_builder_);
}
+std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+ llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
+ llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext());
+ std::vector<llvm::Type*> compute_function_params(
+ {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
+ if (num_dynamic_loop_bounds_ > 0) {
+ compute_function_params.push_back(i64_ptr_type);
+ }
+ if (hlo_to_profile_idx_) {
+ compute_function_params.push_back(i64_ptr_type);
+ }
+ return compute_function_params;
+}
+
llvm::Argument* IrEmitter::GetResultArgument() {
return GetArg(compute_function_, 0);
}
llvm::Argument* IrEmitter::GetProfileCountersArgument() {
- const int64 arg_index = IsParallelContext() ? 5 : 4;
+ const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4;
return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr;
}
@@ -2843,18 +2859,11 @@ llvm::Value* IrEmitter::EmitElementFunctionCall(
AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
}
-// Emits a core function call based on the following pseudo-code.
-//
-// char** parameter_addresses_buffer =
-// allocate buffer with a pointer for each parameter to the function
-// for each parameter index, i.e. for i = 0, ..., #parameters:
-// parameter_addresses_buffer[i] = parameter_addresses[i]
-// call function(return_value_buffer,
-// parameter_addresses_buffer,
-// temps)
-// return return_value_buffer -- address of the return value.
-void IrEmitter::EmitArrayFunctionCallInto(
- llvm::Function* function,
+// Emits code to allocate an array of parameter address pointers, and store
+// each address from 'parameter_addresses'.
+// Returns an array of compute function call arguments (including parameter
+// address buffer).
+std::vector<llvm::Value*> IrEmitter::GetArrayFunctionCallArguments(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
llvm::Value* parameter_addresses_buffer =
@@ -2883,7 +2892,26 @@ void IrEmitter::EmitArrayFunctionCallInto(
if (auto* profile_counters = GetProfileCountersArgument()) {
arguments.push_back(profile_counters);
}
- ir_builder_.CreateCall(function, arguments);
+ return arguments;
+}
+
+// Emits a core function call based on the following pseudo-code.
+//
+// char** parameter_addresses_buffer =
+// allocate buffer with a pointer for each parameter to the function
+// for each parameter index, i.e. for i = 0, ..., #parameters:
+// parameter_addresses_buffer[i] = parameter_addresses[i]
+// call function(return_value_buffer,
+// parameter_addresses_buffer,
+// temps)
+// return return_value_buffer -- address of the return value.
+void IrEmitter::EmitArrayFunctionCallInto(
+ llvm::Function* function,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
+ ir_builder_.CreateCall(
+ function, GetArrayFunctionCallArguments(parameter_addresses,
+ return_value_buffer, name));
}
llvm::Value* IrEmitter::EmitArrayFunctionCall(
@@ -2903,6 +2931,110 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
return return_value_buffer;
}
+// Emits a call to a runtime fork/join function which dispatches parallel
+// calls to 'parallel_function' (and joins threads before returning).
+Status IrEmitter::EmitParallelForkJoin(
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* output_address, HloComputation* computation,
+ llvm::Function* parallel_function) {
+ HloInstruction* root = computation->root_instruction();
+
+ // Build ParallelForkJoin function type.
+ std::vector<llvm::Type*> compute_function_params = GetComputeFunctionParams();
+ // Number of parallel compute functions.
+ compute_function_params.push_back(ir_builder_.getInt32Ty());
+ // Array of partitions. There is an array element for each
+ // partition x partition_dim x 2 (for dimension start and limit).
+ compute_function_params.push_back(
+ llvm::Type::getInt64PtrTy(module_->getContext()));
+ // Number of partitioned most-major dimensions in 'root.shape'.
+ compute_function_params.push_back(ir_builder_.getInt32Ty());
+ // Function pointer for compute function to be dispatched in parallel.
+ compute_function_params.push_back(
+ llvm::Type::getInt8PtrTy(module_->getContext()));
+
+ llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
+ /*Result=*/llvm::Type::getVoidTy(module_->getContext()),
+ /*Params=*/compute_function_params,
+ /*isVarArg=*/false);
+
+ llvm::Function* fork_join_func =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ runtime::kParallelForkJoinSymbolName, fork_join_type));
+ fork_join_func->setCallingConv(llvm::CallingConv::C);
+ fork_join_func->setDoesNotThrow();
+
+ // Add common compute function arguments.
+ const string name = computation->name();
+ std::vector<llvm::Value*> arguments =
+ GetArrayFunctionCallArguments(parameter_addresses, output_address, name);
+
+ // Create ShapePartitionIterator to generate all partitions of 'root.shape'.
+ ShapePartitionIterator partition_iterator(root->shape(),
+ root->outer_dimension_partitions());
+ const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
+ // Add argument specifying the number of parallel partitions.
+ arguments.push_back(ir_builder_.getInt32(num_partitions));
+
+ // The number of partitioned most-major dimensions in 'root.shape'.
+ const int32 num_partitioned_dims = root->outer_dimension_partitions().size();
+ // A dimension partition consists of two elements: [start_index, limit_index).
+ const int32 dim_partition_size = 2;
+ // Calculate array partition stride.
+ const int32 array_partition_stride =
+ num_partitioned_dims * dim_partition_size;
+ // Calculate the total number of elements in the partition array.
+ const int32 partition_array_size =
+ dim_partition_size * num_partitioned_dims * num_partitions;
+
+ // Store dimension partition values as llvm constants in 'partitions'.
+ // See comments in runtime_fork_join.cc for array layout description.
+ std::vector<llvm::Constant*> partitions(partition_array_size);
+ for (int32 i = 0; i < num_partitions; ++i) {
+ std::vector<std::pair<int64, int64>> dim_partitions =
+ partition_iterator.GetPartition(i);
+ CHECK_EQ(num_partitioned_dims, dim_partitions.size());
+ const int32 partition_index = i * array_partition_stride;
+ for (int32 j = 0; j < num_partitioned_dims; ++j) {
+ const std::pair<int64, int64>& dim_partition = dim_partitions[j];
+ const int32 index = partition_index + j * dim_partition_size;
+ // Store partition [dim_start, dim_limit) intervals for each dimension.
+ partitions[index] = ir_builder_.getInt64(dim_partition.first);
+ partitions[index + 1] =
+ ir_builder_.getInt64(dim_partition.first + dim_partition.second);
+ }
+ }
+
+ // Create global variable out of dimension partitions in 'partitions'.
+ llvm::ArrayType* partitions_array_type =
+ llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size);
+ llvm::Constant* partitions_array =
+ llvm::ConstantArray::get(partitions_array_type, partitions);
+ llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
+ /*Module=*/*module_,
+ /*Type=*/partitions_array_type,
+ /*isConstant=*/true,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/partitions_array,
+ /*Name=*/
+ AsStringRef(
+ tensorflow::strings::StrCat(name, "_parallel_dimension_partitions")));
+
+ // Add argument specifying parallel dimension partitions.
+ arguments.push_back(ir_builder_.CreateBitCast(
+ global_partitions_array,
+ llvm::Type::getInt64PtrTy(module_->getContext())));
+ // Add argument specifying the number of partitioned most-major dimensions.
+ arguments.push_back(ir_builder_.getInt32(num_partitioned_dims));
+ // Add argument for parallel compute function pointer.
+ arguments.push_back(
+ ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy()));
+ // Emit call to parallel fork/join.
+ ir_builder_.CreateCall(fork_join_func, arguments);
+
+ return Status::OK();
+}
+
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
llvm::Value* addr;
const Shape& target_shape = op->shape();
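At the host level, the extra i64* parameter that GetComputeFunctionParams adds when num_dynamic_loop_bounds_ > 0 corresponds to the dynamic-loop-bounds argument of the runtime's compute-function type (ComputeFunctionType in runtime_fork_join.cc below). A hedged C++ restatement of the signature that emitted parallel compute functions follow; this alias does not exist verbatim in the patch:

#include <cstdint>

// Host-side view of an emitted parallel compute function. The fork/join
// runtime passes a pointer into the partitions array as 'dynamic_loop_bounds',
// i.e. the [start, limit) intervals of the partition this invocation covers.
using ParallelComputeFunction = void (*)(void* result, const void* run_options,
                                         const void** params, void** temps,
                                         int64_t* dynamic_loop_bounds,
                                         uint64_t* prof_counters);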
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 53c4b6f241..58c185af1e 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -249,6 +249,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Convenience function to get the IR type matching the given shape.
llvm::Type* IrShapeType(const Shape& shape);
+ // Returns an array of compute function parameter types.
+ std::vector<llvm::Type*> GetComputeFunctionParams();
+
// Get the llvm::Value* that represents the "retval" argument of the
// computation function being emitted by this emitter.
llvm::Argument* GetResultArgument();
@@ -323,6 +326,18 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name);
+ // Returns an array of compute function call arguments.
+ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name);
+
+ // Emits a call to a runtime fork/join function which dispatches parallel
+ // calls to 'parallel_function' (and joins threads before returning).
+ Status EmitParallelForkJoin(
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* output_address, HloComputation* computation,
+ llvm::Function* parallel_function);
+
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
Status ElementTypesSameAndSupported(
@@ -596,12 +611,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address);
- // Returns true if the current function being emitted is called in a
- // parallel context (returns false otherwise).
- bool IsParallelContext() {
- return parallel_cpu_backend_ && is_top_level_computation_;
- }
-
const HloModuleConfig& hlo_module_config_;
const bool parallel_cpu_backend_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index d4b5e41f50..7219736b9e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -48,29 +48,56 @@ class SimpleCostModel : public ParallelCostModel {
class DefaultCostModel : public ParallelCostModel {
public:
DefaultCostModel(const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
std::unique_ptr<HloCostAnalysis> cost_analysis)
: max_parallelism_(max_parallelism),
+ shape_size_(shape_size),
cost_analysis_(std::move(cost_analysis)) {}
~DefaultCostModel() override {}
int64 GetParallelTaskCount(HloInstruction* instruction) override {
- // Calculate the instruction cost in cycles.
- // TODO(29630486) Improve on this linear cost model.
- // Consider making 'min_cost_per_thread' be a function of the target
- // bandwidth limit for instructions with low arithmetic complexity.
- const int64 instruction_cost =
- 1 * cost_analysis_->flop_count(*instruction) +
- 2 * cost_analysis_->transcendental_count(*instruction) +
- 10 * cost_analysis_->bytes_accessed(*instruction);
- // Minimum per-thread cost is 100us of work on a 2GHz core.
- const int64 min_cost_per_thread = 100000;
+ // Parameters for parallel task count computation.
+ int64 instruction_cost;
+ int64 min_cost_per_thread;
+ int64 max_parallelism;
+ // Calculate flops-to-bytes-ratio for 'instruction'.
+ const int64 bytes_accessed =
+ std::max(1LL, cost_analysis_->bytes_accessed(*instruction));
+ const float flops_to_bytes_ratio =
+ cost_analysis_->flop_count(*instruction) /
+ static_cast<float>(bytes_accessed);
+ // Check for I/O bound instructions.
+ if (flops_to_bytes_ratio <= 1.0) {
+ // Limit max parallelism for I/O bound instructions by assuming a
+ // sub-linear scaling function (fit based on empirical benchmark results).
+ // TODO(29630486) Develop system bandwidth model.
+ max_parallelism =
+ std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs()));
+ // Use the shape size as the instruction cost and the L2 cache size as the min per-thread cost.
+ instruction_cost = shape_size_(instruction->shape());
+ min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
+ } else {
+ // Use max parallelism for compute bound instructions.
+ max_parallelism = max_parallelism_;
+ // Calculate the instruction cost in cycles.
+ // TODO(29630486) Improve on this linear cost model.
+ // Consider making 'min_cost_per_thread' be a function of the target
+ // bandwidth limit for instructions with low arithmetic complexity.
+ instruction_cost =
+ 1 * cost_analysis_->flop_count(*instruction) +
+ 2 * cost_analysis_->transcendental_count(*instruction) +
+ 10 * cost_analysis_->bytes_accessed(*instruction);
+ // Minimum per-thread cost is 100us of work on a 2GHz core.
+ min_cost_per_thread = 100000;
+ }
// Return target parallel task count in [1, max_parallelism_].
- return std::min(max_parallelism_,
+ return std::min(max_parallelism,
std::max(1LL, instruction_cost / min_cost_per_thread));
}
private:
const int64 max_parallelism_;
+ const HloCostAnalysis::ShapeSizeFunction shape_size_;
const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};
@@ -86,7 +113,7 @@ ParallelTaskAssignment::ParallelTaskAssignment(
Status status = computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
// Set default cost model based on 'cost_analysis'.
- cost_model_.reset(new DefaultCostModel(max_parallelism,
+ cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size,
std::move(cost_analysis)));
} else {
// Fall back to a simple cost model based on hlo size and L2 cache size.
@@ -121,5 +148,102 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
return cost_model_->GetParallelTaskCount(instruction);
}
+StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
+ XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
+ XLA_VLOG_LINES(3, module->ToString());
+
+ // Compute target parallel task counts for all instructions in 'module'.
+ HloToParallelTasks hlo_to_parallel_tasks;
+ ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
+
+ // Assign parallel tasks to target specific instructions in 'module'.
+ // TODO(b/27458679) Support inter-op parallelism.
+ bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks);
+
+ XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT");
+ XLA_VLOG_LINES(3, module->ToString());
+ return changed;
+}
+
+bool ParallelTaskAssigner::AssignParallelTasks(
+ HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) {
+ return AssignParallelTasksHelper(module, module->entry_computation(),
+ hlo_to_parallel_tasks);
+}
+
+bool ParallelTaskAssigner::AssignParallelTasksHelper(
+ HloModule* module, HloComputation* computation,
+ const HloToParallelTasks& hlo_to_parallel_tasks) {
+ bool changed = false;
+ // Snapshot set of instructions because outlining modifies the set below.
+ std::vector<HloInstruction*> instructions(computation->instructions().begin(),
+ computation->instructions().end());
+ for (auto* instruction : instructions) {
+ // Assign parallel tasks to sub-computations for While and Call HLOs.
+ // TODO(b/27458679) Evaluate alternative intra-op parallelism placement,
+ // and support other callable computations like reduce.
+ if (instruction->opcode() == HloOpcode::kWhile) {
+ changed |= AssignParallelTasksHelper(module, instruction->while_body(),
+ hlo_to_parallel_tasks);
+ continue;
+ } else if (instruction->opcode() == HloOpcode::kCall) {
+ changed |= AssignParallelTasksHelper(module, instruction->to_apply(),
+ hlo_to_parallel_tasks);
+ continue;
+ }
+ // Skip if no parallel tasks were computed in first pass.
+ auto it = hlo_to_parallel_tasks.find(instruction);
+ if (it == hlo_to_parallel_tasks.end()) {
+ continue;
+ }
+ // Get target parallel task count computed for 'instruction'.
+ const int64 target_parallel_task_count = (*it).second;
+ // Assign feasible dimension partitions (based on actual dimension sizes).
+ auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
+ .Run(target_parallel_task_count);
+ const int64 total_partition_count =
+ ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
+ if (total_partition_count <= 1) {
+ // Feasible partition calculation resulted in no partitioning, so skip.
+ continue;
+ }
+
+ // Outline 'instruction' in 'computation' for parallel task assignment.
+ auto* call = module->OutlineExpressionFromComputation(
+ {instruction},
+ tensorflow::strings::StrCat("parallel_", instruction->name()),
+ computation);
+
+ // Set assigned dimension partitioning to 'instruction'.
+ auto* new_root = call->to_apply()->root_instruction();
+ new_root->set_outer_dimension_partitions(dim_partition_counts);
+
+ VLOG(2) << "Assigned parallel task count: " << total_partition_count
+ << " to instruction: " << new_root->name()
+ << " parent: " << new_root->parent()->name();
+ changed = true;
+ }
+ return changed;
+}
+
+void ParallelTaskAssigner::ComputeTargetParallelTasks(
+ HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
+ // Compute parallel task counts for all instructions in 'module'.
+ for (auto* computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ for (auto* instruction : computation->instructions()) {
+ // Query ParallelTaskAssignment for target parallel task count.
+ const int64 target_parallel_task_count =
+ parallel_task_assignment_.GetTargetParallelTaskCount(instruction);
+ if (target_parallel_task_count > 1) {
+ hlo_to_parallel_tasks->insert(
+ {instruction, target_parallel_task_count});
+ }
+ }
+ }
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 15f065a3ad..e036da5784 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
namespace cpu {
@@ -49,6 +50,54 @@ class ParallelTaskAssignment {
std::unique_ptr<ParallelCostModel> cost_model_;
};
+// ParallelTaskAssigner computes target parallel task counts for all HLOs
+// in the module, then assigns parallel task counts to HLOs in the entry
+// computation, or to HLOs in embedded computations invoked by (potentially
+// nested) kWhile or kCall instructions.
+// Each HLO which is assigned parallel task counts is outlined into its
+// own embedded computation, which is compiled as a parallel compute function,
+// and which is invoked from a kCall instruction that is lowered in codegen to
+// a runtime parallel fork/join call.
+class ParallelTaskAssigner : public HloPassInterface {
+ public:
+ // 'max_parallelism': the maximum parallel task count per instruction.
+ // 'shape_size': shape size function used by HloCostAnalysis during parallel
+ // task assignment.
+ // 'module': the containing HloModule.
+ ParallelTaskAssigner(const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
+ HloModule* module)
+ : parallel_task_assignment_(max_parallelism, shape_size, module) {}
+ ~ParallelTaskAssigner() override {}
+
+ tensorflow::StringPiece name() const override {
+ return "cpu-parallel-task-assigner";
+ }
+
+ // Run parallel task assigner on 'module'.
+ // Returns true if the computation was changed, false otherwise.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ using HloToParallelTasks = std::unordered_map<const HloInstruction*, int64>;
+
+ // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in
+ // 'module'.
+ // Returns true if the computation was changed, false otherwise.
+ bool AssignParallelTasks(HloModule* module,
+ const HloToParallelTasks& hlo_to_parallel_tasks);
+ bool AssignParallelTasksHelper(
+ HloModule* module, HloComputation* computation,
+ const HloToParallelTasks& hlo_to_parallel_tasks);
+
+ // Computes target parallel task counts (returned in 'hlo_to_parallel_tasks')
+ // for parallelizable instructions in 'module'.
+ void ComputeTargetParallelTasks(HloModule* module,
+ HloToParallelTasks* hlo_to_parallel_tasks);
+
+ ParallelTaskAssignment parallel_task_assignment_;
+};
+
} // namespace cpu
} // namespace xla
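For orientation, a minimal usage sketch of the pass declared above, assuming the caller supplies max_parallelism, a shape-size function, and the module (CpuCompiler::RunHloPasses does the equivalent by adding the pass to its HloPassPipeline, as shown in the cpu_compiler.cc hunk earlier):

#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"

// Run the assigner directly on a module; 'changed' reports whether any HLO
// was outlined and given outer-dimension partitions.
xla::cpu::ParallelTaskAssigner assigner(max_parallelism, shape_size, module);
TF_ASSIGN_OR_RETURN(bool changed, assigner.Run(module));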
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
new file mode 100644
index 0000000000..af2f3de6b8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/logging.h"
+
+using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
+ int64*, uint64*);
+
+// Dispatches 'num_partitions - 1' calls to 'function_ptr' in parallel.
+// Calls 'function_ptr' for first partition inline.
+// Uses a blocking counter to synchronize threads after parallel calls complete.
+//
+// The 'partitions' array has a total number of elements equal to
+// 'num_partitions * num_partitioned_dims * 2' (the '2' is necessary to specify
+// dimension start and limit indices).
+//
+// The 'partitions' array layout stores elements in memory with the dimension
+// start/limit pair as the most-minor dimension, followed by dimension, then
+// partition.
+//
+// EX: Layout of 'partitions' array with 'num_partitions = 2', and
+// 'num_partitioned_dims = 3'
+//
+// [partition0_dim0_start]
+// [partition0_dim0_limit]
+// [partition0_dim1_start]
+// [partition0_dim1_limit]
+// [partition0_dim2_start]
+// [partition0_dim2_limit]
+// [partition1_dim0_start]
+// [partition1_dim0_limit]
+// [partition1_dim1_start]
+// [partition1_dim1_limit]
+// [partition1_dim2_start]
+// [partition1_dim2_limit]
+//
+void __xla_cpu_runtime_ParallelForkJoin(
+ void* result_ptr, const void* run_options_ptr, const void** params,
+ void** temps, uint64* prof_counters, tensorflow::int32 num_partitions,
+ tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims,
+ void* function_ptr) {
+ VLOG(2) << "ParallelForkJoin ENTRY"
+ << " num_partitions: " << num_partitions
+ << " num_partitioned_dims: " << num_partitioned_dims;
+ CHECK_GT(num_partitions, 1);
+ CHECK_GT(num_partitioned_dims, 0);
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ ComputeFunctionType function =
+ reinterpret_cast<ComputeFunctionType>(function_ptr);
+ // Compute partition stride in 'partitions' array.
+ const int64 stride = 2 * num_partitioned_dims;
+
+ // Dispatch 'num_partitions - 1' compute functions to run in parallel.
+ tensorflow::BlockingCounter bc(num_partitions - 1);
+ for (tensorflow::int32 i = 1; i < num_partitions; ++i) {
+ const int64 offset = i * stride;
+ run_options->intra_op_thread_pool()->enqueue_function(
+ [i, function, result_ptr, run_options_ptr, params, temps, prof_counters,
+ partitions, offset, &bc]() {
+ function(result_ptr, run_options_ptr, params, temps,
+ &partitions[offset], prof_counters);
+ bc.DecrementCount();
+ VLOG(3) << "ParallelForkJoin partition " << i << " done.";
+ });
+ }
+
+ // Call first compute function inline.
+ function(result_ptr, run_options_ptr, params, temps, &partitions[0],
+ prof_counters);
+ VLOG(3) << "ParallelForkJoin partition 0 done.";
+ bc.Wait();
+ VLOG(2) << "ParallelForkJoin EXIT";
+}
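Given the layout documented above, the index arithmetic for recovering one dimension's interval from the flat partitions array is shown below as a standalone sketch (the runtime itself only ever hands whole per-partition slices, &partitions[i * 2 * num_partitioned_dims], to the compute function; this helper is illustrative, not part of the patch):

#include <cstdint>
#include <utility>

// Returns the [start, limit) interval of dimension 'dim' within partition 'p'
// of a flat 'partitions' array with 'num_partitioned_dims' dimensions.
std::pair<int64_t, int64_t> DimInterval(const int64_t* partitions,
                                        int32_t num_partitioned_dims,
                                        int32_t p, int32_t dim) {
  const int64_t offset =
      static_cast<int64_t>(p) * 2 * num_partitioned_dims + 2 * dim;
  return {partitions[offset], partitions[offset + 1]};
}

The IR emitter's EmitParallelForkJoin builds this array as a private constant LLVM global with exactly this stride (2 * num_partitioned_dims elements per partition).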
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
new file mode 100644
index 0000000000..1ddcaf5274
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Dispatches 'num_partitions' parallel calls to 'function_ptr' and joins
+// threads before returning. See comments in runtime_fork_join.cc for details.
+extern void __xla_cpu_runtime_ParallelForkJoin(
+ void* result_ptr, const void* run_options_ptr, const void** params,
+ void** temps, uint64* prof_counters, tensorflow::int32 num_partitions,
+ tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims,
+ void* function_ptr);
+
+} // extern "C"
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index c614e334a8..cfffb3fbc3 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
@@ -104,6 +105,7 @@ class JITSymbolTable {
ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32);
ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64);
+ ADD_JIT_SYMBOL_TO_TABLE(ParallelForkJoin);
#undef ADD_JIT_SYMBOL_TO_TABLE
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index f37a331a72..256ec71ab5 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1410,8 +1410,10 @@ xla_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 3bf9ccb197..a8f6488996 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -17,8 +17,12 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <new>
+#include <random>
#include <utility>
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
@@ -37,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -250,6 +255,42 @@ XLA_TEST_F(FusionTest, Parameter) {
ErrorSpec(1e-4));
}
+XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
+ // Tests parallel partitioning of a fusion instruction.
+ // Create shape with random outer dimension size to generate random parallel
+ // partition counts for each test run.
+ const int seed = tensorflow::testing::RandomSeed();
+ LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
+ std::mt19937 generator(seed);
+ std::uniform_int_distribution<int> distribution(128, 1024);
+ const int64 rand_dim0_size = distribution(generator);
+ const int64 dim1_size = 1024;
+ Shape shape =
+ ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
+ // Build simple fusion computation: y = x^2 (elementwise).
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = CreateNewModule();
+
+ auto two = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
+
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
+ HloInstruction::FusionKind::kLoop);
+ // Compute result.
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ // Every element of result should be y = x^2 = 4.0.
+ for (int i = 0; i < rand_dim0_size; ++i) {
+ for (int j = 0; j < dim1_size; ++j) {
+ EXPECT_EQ(4.0, result->Get<float>({i, j}));
+ }
+ }
+}
+
XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
@@ -722,47 +763,104 @@ void BM_ParallelFusion(int num_iters) {
auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
StreamExecutorMemoryAllocator allocator(platform, executors);
- const int64 intra_op_parallelism_threads = 16;
+ const int64 intra_op_parallelism_threads = 24;
xla::LocalClientOptions client_options;
client_options.set_platform(platform);
client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
auto client =
ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
- const int64 dim_size = 1024;
- // Create a simple fusable elementwise computation.
+ auto* transfer_manager =
+ TransferManager::GetForPlatform(platform).ValueOrDie();
+ int device_ordinal = client->default_device_ordinal();
+
+ // Computation shape parameters.
+ const int64 param0_dim0 = 1024;
+ const int64 param0_dim1 = 1024;
+ const int64 param1_dim0 = 1024;
+ const int64 param1_dim1 = 1024;
+ const int64 param2_dim0 = 1024;
+ const int64 param2_dim1 = 1024;
+
+ // Create computation.
ComputationBuilder builder(client, "ParallelFusion");
- Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size});
- auto input0 = builder.Broadcast(builder.ConstantR0<float>(1.5f),
- AsInt64Slice(input_shape.dimensions()));
- auto input1 = builder.Broadcast(builder.ConstantR0<float>(2.0f),
- AsInt64Slice(input_shape.dimensions()));
- auto input2 = builder.Broadcast(builder.ConstantR0<float>(3.0f),
- AsInt64Slice(input_shape.dimensions()));
- auto x = builder.Mul(input0, input1);
- auto y = builder.Add(x, input2);
+ Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
+ auto param0 = builder.Parameter(0, shape0, "param0");
+ Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
+ auto param1 = builder.Parameter(1, shape1, "param1");
+ Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
+ auto param2 = builder.Parameter(2, shape2, "param2");
+
+ auto x = builder.Mul(param0, param1);
+ auto y = builder.Add(x, param2);
auto computation = builder.Build().ConsumeValueOrDie();
+ // Transfer literals to device.
+ auto buffer0 =
+ ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0)
+ .ConsumeValueOrDie();
+ auto param0_literal =
+ Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({})));
+
+ auto buffer1 =
+ ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0)
+ .ConsumeValueOrDie();
+ auto param1_literal =
+ Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({})));
+
+ auto buffer2 =
+ ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0)
+ .ConsumeValueOrDie();
+ auto param2_literal =
+ Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({})));
+
+ // Build executable.
std::unique_ptr<LocalExecutable> executable =
- client->Compile(computation, {}, ExecutableBuildOptions())
+ client
+ ->Compile(computation,
+ {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()},
+ ExecutableBuildOptions())
.ConsumeValueOrDie();
- // Run some warm-up executions.
+ se::Stream stream(executors[client->default_device_ordinal()]);
+ stream.Init();
+
+ // Initialize thread pool.
+ tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
+ intra_op_parallelism_threads);
+ tensorflow::EigenThreadPoolWrapper tp(&pool);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ // Initialize ExecutableRunOptions.
ExecutableRunOptions options;
- options.set_allocator(&allocator);
+ options.set_allocator(&allocator).set_stream(&stream);
+ options.set_intra_op_thread_pool(&device);
+
+ // Run some warm-up executions.
const int kWarmups = 2;
for (int i = 0; i < kWarmups; ++i) {
- auto result = executable->Run({}, options);
+ auto result =
+ executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
ASSERT_TRUE(result.ok());
}
// Run benchmark.
- tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * dim_size *
- dim_size * sizeof(float));
+ const int64 total_bytes = param0_dim0 * param0_dim1 +
+ param1_dim0 * param1_dim1 +
+ param2_dim0 * param2_dim1;
+ tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
+ total_bytes * sizeof(float));
tensorflow::testing::UseRealTime();
tensorflow::testing::StartTiming();
for (int i = 0; i < num_iters; ++i) {
- auto result = executable->Run({}, options);
+ auto result =
+ executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
ASSERT_TRUE(result.ok());
}
}