author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-10-11 14:45:28 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-10-11 14:49:20 -0700
commit     10d0ae696c7b5618cae9e3845af8300fe62870a2
tree       03b971813653820bce5f366bdebd0aa01f687a73
parent     f640c8980571d7578e891ea5ceab55978c8db9b4
[XLA:CPU] Adds intra-op parallelism to the "sequential" CPU backend (which already has intra-op parallelism for library calls).
Adds support for parallel task assignment to instructions in entry (or embedded) computations.
Adds code to emit calls to a new runtime parallel fork/join function for instructions which have been assigned parallel tasks.
Adds a simple cost model for I/O-bound instructions.
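The I/O-bound heuristic (implemented in DefaultCostModel::GetParallelTaskCount in the diff below) condenses to the following sketch. The free-function form and parameter names here are illustrative only; the real implementation pulls its inputs from HloCostAnalysis.

// Sketch of the per-instruction parallel task count heuristic added in this
// CL. Inputs are plain integers here; in the actual code they come from
// HloCostAnalysis and the shape size function.
#include <algorithm>
#include <cmath>
#include <cstdint>

int64_t TargetParallelTaskCount(int64_t flop_count,
                                int64_t transcendental_count,
                                int64_t bytes_accessed,
                                int64_t shape_size_bytes,
                                int64_t max_parallelism,
                                int num_schedulable_cpus) {
  bytes_accessed = std::max<int64_t>(1, bytes_accessed);
  const float flops_to_bytes_ratio =
      flop_count / static_cast<float>(bytes_accessed);
  int64_t instruction_cost, min_cost_per_thread, parallelism_cap;
  if (flops_to_bytes_ratio <= 1.0f) {
    // I/O-bound: cap parallelism sub-linearly in the core count (an
    // empirical fit), and require at least an L2 cache (256KB) of output
    // per thread.
    parallelism_cap =
        static_cast<int64_t>(std::ceil(std::sqrt(num_schedulable_cpus)));
    instruction_cost = shape_size_bytes;
    min_cost_per_thread = 256LL << 10;
  } else {
    // Compute-bound: linear cycle-count model, with a minimum of 100us of
    // work (on a 2GHz core) per thread.
    parallelism_cap = max_parallelism;
    instruction_cost =
        1 * flop_count + 2 * transcendental_count + 10 * bytes_accessed;
    min_cost_per_thread = 100000;
  }
  // Target task count in [1, parallelism_cap].
  return std::min(parallelism_cap,
                  std::max<int64_t>(1, instruction_cost / min_cost_per_thread));
}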
*) Translation (deleuze model) wall time (seconds).
              large_model   small_model   small_model_small_attn
sequential:   0.00556       0.00484       0.00155
parallel:     0.00263       0.00163       0.00106
*) Wavenet
sequential: Avg. latency (30 runs): 1026.13ms, min/max: 988/1108ms
parallel: Avg. latency (30 runs): 800.633ms, min/max: 785/818ms
*) ParallelFusion benchmark.
Benchmark                           Time(ns)   CPU(ns)   Iterations
--------------------------------------------------------------------
sequential cpu backend (at head)      610584    611467         1000
parallel cpu backend                  153241    836097         4528
sequential cpu backend (this CL)      113482    679535         6017
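That is roughly a 5.4x wall-time improvement over the sequential backend at head (610584 / 113482), and about 1.35x over the parallel backend (153241 / 113482).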
PiperOrigin-RevId: 171877766
15 files changed, 644 insertions, 91 deletions
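Before the patch body, a quick note on the partition array that the new runtime consumes. Per the layout comment in runtime_fork_join.cc below, [start, limit) pairs are most-minor, followed by dimension, then partition, so indexing reduces to the helpers sketched here (names are hypothetical, for illustration only):

// Illustrative indexing into the flat 'partitions' array passed to
// __xla_cpu_runtime_ParallelForkJoin. Each partition occupies
// 2 * num_partitioned_dims consecutive int64 elements.
#include <cstdint>

inline int64_t PartitionStart(const int64_t* partitions,
                              int32_t num_partitioned_dims,
                              int32_t partition, int32_t dim) {
  const int64_t stride = 2 * num_partitioned_dims;  // elements per partition
  return partitions[partition * stride + dim * 2];
}

inline int64_t PartitionLimit(const int64_t* partitions,
                              int32_t num_partitioned_dims,
                              int32_t partition, int32_t dim) {
  const int64_t stride = 2 * num_partitioned_dims;
  return partitions[partition * stride + dim * 2 + 1];
}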
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 8ab358fe17..c71eca0d39 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -87,6 +87,7 @@ cc_library( ":ir_emitter", ":layout_assignment", ":parallel_cpu_executable", + ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -155,6 +156,7 @@ cc_library( ":disassembler", ":external_constant_pool", ":runtime_conv2d", + ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", @@ -243,6 +245,7 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -505,6 +508,20 @@ cc_library( ], ) +cc_library( + name = "runtime_fork_join", + srcs = ["runtime_fork_join.cc"], + hdrs = ["runtime_fork_join.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], @@ -688,6 +705,7 @@ cc_library( ":shape_partition", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 386800d221..3272044faa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -248,7 +249,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction()); @@ -316,6 +317,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { if (options::CpuParallelBackendRequested(module->config())) { pipeline.AddPass<ParallelizationPreparation>(max_parallelism, ShapeSizeBytesFunction()); + } else if (!is_aot_compile) { + // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. + // Note this is not run for AOT because it would bring in thread pool + // and thread synchronization dependencies which would likely increase + // binary size (and most AOT applications are single-threaded). + // TODO(29630486) Support multi-threaded AOT. 
+ pipeline.AddPass<ParallelTaskAssigner>(max_parallelism, + ShapeSizeBytesFunction(), module); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -450,7 +459,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get())); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); HloComputation* computation = module->entry_computation(); std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx; @@ -749,7 +758,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module)); + TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index bd3541500d..21dd128619 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -131,7 +131,7 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module); + Status RunHloPasses(HloModule* module, bool is_aot_compile); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 2cd0aa7880..662ee60923 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -116,26 +116,6 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment( // Assign parallel tasks to HLOs in entry computation. HloComputation* computation = module->entry_computation(); for (auto* instruction : computation->instructions()) { - // Currently, we do not assign parallel tasks to instructions with at least - // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). - // *) Tuple-shaped. - // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { - continue; - } - // Calculate target parallel task count in [1, max_parallelism_]. 
const int64 target_parallel_task_count = parallel_task_assignment.GetTargetParallelTaskCount(instruction); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index c7155b858b..7908dc173d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,9 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName = "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; +extern const char* const kParallelForkJoinSymbolName = + "__xla_cpu_runtime_ParallelForkJoin"; + extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 29feb7267f..2ade455b8a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; +extern const char* const kParallelForkJoinSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 633ad0290c..c38325554f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -186,20 +187,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { // Even though the type of params and temps is void** in the host's view, in // LLVM IR this is represented by i8*, similarly to void*. It's up to the code // to use GEPs to unravel the indirection layers. - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector<llvm::Type*> compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (IsParallelContext()) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } llvm::FunctionType* compute_function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, + /*Params=*/GetComputeFunctionParams(), /*isVarArg=*/false); // Functions with local linkage get an inlining bonus. 
Because we know @@ -221,7 +211,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); - if (IsParallelContext()) { + if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); } if (hlo_to_profile_idx_) { @@ -2286,8 +2276,19 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + + if (!computation->root_instruction()->outer_dimension_partitions().empty() && + !parallel_cpu_backend_) { + // ParallelTaskAssignment assigned partitions, emit call to + // ParallelForkJoin. + TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, + emitted_value_[call], computation, + call_ir_function)); + } else { + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); + } + return Status::OK(); } @@ -2597,7 +2598,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // For the parallel cpu backend, we record the total for each embedded // computation callee with its caller kCall HLO. HloInstruction* hlo_to_lookup = nullptr; - if (IsParallelContext()) { + if (parallel_cpu_backend_ && is_top_level_computation_) { auto* computation = root->parent(); auto* entry_computation = computation->parent()->entry_computation(); if (computation != entry_computation) { @@ -2755,12 +2756,27 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, &ir_builder_); } +std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); + std::vector<llvm::Type*> compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } + if (hlo_to_profile_idx_) { + compute_function_params.push_back(i64_ptr_type); + } + return compute_function_params; +} + llvm::Argument* IrEmitter::GetResultArgument() { return GetArg(compute_function_, 0); } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = IsParallelContext() ? 5 : 4; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } @@ -2843,18 +2859,11 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). 
+std::vector<llvm::Value*> IrEmitter::GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { llvm::Value* parameter_addresses_buffer = @@ -2883,7 +2892,26 @@ void IrEmitter::EmitArrayFunctionCallInto( if (auto* profile_counters = GetProfileCountersArgument()) { arguments.push_back(profile_counters); } - ir_builder_.CreateCall(function, arguments); + return arguments; +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + ir_builder_.CreateCall( + function, GetArrayFunctionCallArguments(parameter_addresses, + return_value_buffer, name)); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2903,6 +2931,110 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status IrEmitter::EmitParallelForkJoin( + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function) { + HloInstruction* root = computation->root_instruction(); + + // Build ParallelForkJoin function type. + std::vector<llvm::Type*> compute_function_params = GetComputeFunctionParams(); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module_->getContext())); + // Number of partitioned most-major dimensions in 'root.shape'. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module_->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module_->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast<llvm::Function>(module_->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + const string name = computation->name(); + std::vector<llvm::Value*> arguments = + GetArrayFunctionCallArguments(parameter_addresses, output_address, name); + + // Create ShapePartitionIterator to generate all partitions of 'root.shape'. + ShapePartitionIterator partition_iterator(root->shape(), + root->outer_dimension_partitions()); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. 
+ arguments.push_back(ir_builder_.getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'root.shape'. + const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector<llvm::Constant*> partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector<std::pair<int64, int64>> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair<int64, int64>& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder_.getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder_.getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + arguments.push_back(ir_builder_.CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module_->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + arguments.push_back( + ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder_.CreateCall(fork_join_func, arguments); + + return Status::OK(); +} + Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 53c4b6f241..58c185af1e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -249,6 +249,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); + // Returns an array of compute function parameter types. 
+ std::vector<llvm::Type*> GetComputeFunctionParams(); + // Get the llvm::Value* that represents the "retval" argument of the // computation function being emitted by this emitter. llvm::Argument* GetResultArgument(); @@ -323,6 +326,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, tensorflow::StringPiece name); + // Returns an array of compute function call arguments. + std::vector<llvm::Value*> GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Emits a call to a runtime fork/join function which dispatches parallel + // calls to 'parallel_function' (and joins threads before returning). + Status EmitParallelForkJoin( + tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function); + // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -596,12 +611,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - // Returns true if the current function being emitted is called in a - // parallel context (returns false otherwise). - bool IsParallelContext() { - return parallel_cpu_backend_ && is_top_level_computation_; - } - const HloModuleConfig& hlo_module_config_; const bool parallel_cpu_backend_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index d4b5e41f50..7219736b9e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -48,29 +48,56 @@ class SimpleCostModel : public ParallelCostModel { class DefaultCostModel : public ParallelCostModel { public: DefaultCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, std::unique_ptr<HloCostAnalysis> cost_analysis) : max_parallelism_(max_parallelism), + shape_size_(shape_size), cost_analysis_(std::move(cost_analysis)) {} ~DefaultCostModel() override {} int64 GetParallelTaskCount(HloInstruction* instruction) override { - // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. - // Consider making 'min_cost_per_thread' be a function of the target - // bandwidth limit for instructions with low arithmetic complexity. - const int64 instruction_cost = - 1 * cost_analysis_->flop_count(*instruction) + - 2 * cost_analysis_->transcendental_count(*instruction) + - 10 * cost_analysis_->bytes_accessed(*instruction); - // Minimum per-thread cost is 100us of work on a 2GHz core. - const int64 min_cost_per_thread = 100000; + // Parameters for parallel task count computation. + int64 instruction_cost; + int64 min_cost_per_thread; + int64 max_parallelism; + // Calculate flops-to-bytes-ratio for 'instruction'. + const int64 bytes_accessed = + std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + const float flops_to_bytes_ratio = + cost_analysis_->flop_count(*instruction) / + static_cast<float>(bytes_accessed); + // Check for I/O bound instructions. 
+ if (flops_to_bytes_ratio <= 1.0) { + // Limit max parallelism for I/O bound instructions by assuming a + // sub-linear scaling function (fit based on emperical benchmark results). + // TODO(29630486) Develop system bandwidth model. + max_parallelism = + std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); + // Use shape size instruction cost and L2 cache size min per-thread cost. + instruction_cost = shape_size_(instruction->shape()); + min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + } else { + // Use max parallelism for compute bound instructions. + max_parallelism = max_parallelism_; + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = + 1 * cost_analysis_->flop_count(*instruction) + + 2 * cost_analysis_->transcendental_count(*instruction) + + 10 * cost_analysis_->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } // Return target parallel task count in [1, max_parallelism_]. - return std::min(max_parallelism_, + return std::min(max_parallelism, std::max(1LL, instruction_cost / min_cost_per_thread)); } private: const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; const std::unique_ptr<HloCostAnalysis> cost_analysis_; }; @@ -86,7 +113,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { // Set default cost model based on 'cost_analysis'. - cost_model_.reset(new DefaultCostModel(max_parallelism, + cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size, std::move(cost_analysis))); } else { // Fall back to a simple cost model based on hlo size and L2 cache size. @@ -121,5 +148,102 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( return cost_model_->GetParallelTaskCount(instruction); } +StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); + XLA_VLOG_LINES(3, module->ToString()); + + // Compute target parallel task counts for all instructions in 'module'. + HloToParallelTasks hlo_to_parallel_tasks; + ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); + + // Assign parallel tasks to target specific instructions in 'module'. + // TODO(b/27458679) Support inter-op parallelism. + bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks); + + XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT"); + XLA_VLOG_LINES(3, module->ToString()); + return changed; +} + +bool ParallelTaskAssigner::AssignParallelTasks( + HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) { + return AssignParallelTasksHelper(module, module->entry_computation(), + hlo_to_parallel_tasks); +} + +bool ParallelTaskAssigner::AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks) { + bool changed = false; + // Snapshot set of instructions because outlining modifies the set below. + std::vector<HloInstruction*> instructions(computation->instructions().begin(), + computation->instructions().end()); + for (auto* instruction : instructions) { + // Assign parallel tasks to sub-computations for While and Call HLOs. 
+ // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement, + // and support other callable computations like reduce. + if (instruction->opcode() == HloOpcode::kWhile) { + changed |= AssignParallelTasksHelper(module, instruction->while_body(), + hlo_to_parallel_tasks); + continue; + } else if (instruction->opcode() == HloOpcode::kCall) { + changed |= AssignParallelTasksHelper(module, instruction->to_apply(), + hlo_to_parallel_tasks); + continue; + } + // Skip if no parallel tasks were computed in first pass. + auto it = hlo_to_parallel_tasks.find(instruction); + if (it == hlo_to_parallel_tasks.end()) { + continue; + } + // Get target parallel task count computed for 'instruction'. + const int64 target_parallel_task_count = (*it).second; + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + + // Outline 'instruction' in 'computation' for parallel task assignment. + auto* call = module->OutlineExpressionFromComputation( + {instruction}, + tensorflow::strings::StrCat("parallel_", instruction->name()), + computation); + + // Set assigned dimension partitioning to 'instruction'. + auto* new_root = call->to_apply()->root_instruction(); + new_root->set_outer_dimension_partitions(dim_partition_counts); + + VLOG(2) << "Assigned parallel task count: " << total_partition_count + << " to instruction: " << new_root->name() + << " parent: " << new_root->parent()->name(); + changed = true; + } + return changed; +} + +void ParallelTaskAssigner::ComputeTargetParallelTasks( + HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + // Compute parallel task counts for all instructions in 'module'. + for (auto* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto* instruction : computation->instructions()) { + // Query ParallelTaskAssignment for target parallel task count. + const int64 target_parallel_task_count = + parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + if (target_parallel_task_count > 1) { + hlo_to_parallel_tasks->insert( + {instruction, target_parallel_task_count}); + } + } + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 15f065a3ad..e036da5784 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace cpu { @@ -49,6 +50,54 @@ class ParallelTaskAssignment { std::unique_ptr<ParallelCostModel> cost_model_; }; +// ParallelTaskAssigner computes target parallel task counts for all HLOs +// in the module, then assigns parallel task counts to HLOs in the entry +// computation, or to HLOs in embedded computations invoked by (potentially +// nested) kWhile or kCall instructions. 
+// Each HLO which is assigned parallel task counts is outlined into its +// own embedded computation, which is compiled as a parallel compute function, +// and which is invoked from a kCall instruction that is lowered in codegen to +// a runtime parallel fork/join call. +class ParallelTaskAssigner : public HloPassInterface { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssigner(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) + : parallel_task_assignment_(max_parallelism, shape_size, module) {} + ~ParallelTaskAssigner() override {} + + tensorflow::StringPiece name() const override { + return "cpu-parallel-task-assigner"; + } + + // Run parallel task assigner on 'module'. + // Returns true if the computation was changed, false otherwise. + StatusOr<bool> Run(HloModule* module) override; + + private: + using HloToParallelTasks = std::unordered_map<const HloInstruction*, int64>; + + // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in + // 'module'. + // Returns true if the computation was changed, false otherwise. + bool AssignParallelTasks(HloModule* module, + const HloToParallelTasks& hlo_to_parallel_tasks); + bool AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks); + + // Computes target parallel task counts (returned in 'parallel_task_counts') + // for parallelizable instructions in 'module'. + void ComputeTargetParallelTasks(HloModule* module, + HloToParallelTasks* hlo_to_parallel_tasks); + + ParallelTaskAssignment parallel_task_assignment_; +}; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc new file mode 100644 index 0000000000..af2f3de6b8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" + +using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + int64*, uint64*); + +// Dispatches 'num_partitions - 1' calls to 'function_ptr' in parallel. +// Calls 'function_ptr' for first partition inline. +// Uses blocking counter to synchonize threads after parallel calls complete. 
+// +// The 'partitions' array has a total number of elements equal to +// 'num_partitions * num_partitioned_dims * 2' (the '2' is necessary to specify +// dimension start and limit indices). +// +// The 'partitions' array layout stores array elements in memory with dimension +// start limit as the most-minor dimension, followed by dimension, then +// partition. +// +// EX: Layout of 'partitions' array with 'num_partitions = 2', and +// 'num_partitioned_dims = 3' +// +// [partition0_dim0_start] +// [partition0_dim0_limit] +// [partition0_dim1_start] +// [partition0_dim1_limit] +// [partition0_dim2_start] +// [partition0_dim2_limit] +// [partition1_dim0_start] +// [partition1_dim0_limit] +// [partition1_dim1_start] +// [partition1_dim1_limit] +// [partition1_dim2_start] +// [partition1_dim2_limit] +// +void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, uint64* prof_counters, tensorflow::int32 num_partitions, + tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, + void* function_ptr) { + VLOG(2) << "ParallelForkJoin ENTRY" + << " num_partitions: " << num_partitions + << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_GT(num_partitions, 1); + CHECK_GT(num_partitioned_dims, 0); + const xla::ExecutableRunOptions* run_options = + static_cast<const xla::ExecutableRunOptions*>(run_options_ptr); + ComputeFunctionType function = + reinterpret_cast<ComputeFunctionType>(function_ptr); + // Compute partition stride in 'partitions' array. + const int64 stride = 2 * num_partitioned_dims; + + // Dispatch 'num_partitions - 1' compute functions to run in parallel. + tensorflow::BlockingCounter bc(num_partitions - 1); + for (tensorflow::int32 i = 1; i < num_partitions; ++i) { + const int64 offset = i * stride; + run_options->intra_op_thread_pool()->enqueue_function( + [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + partitions, offset, &bc]() { + function(result_ptr, run_options_ptr, params, temps, + &partitions[offset], prof_counters); + bc.DecrementCount(); + VLOG(3) << "ParallelForkJoin partition " << i << " done."; + }); + } + + // Call first compute function inline. + function(result_ptr, run_options_ptr, params, temps, &partitions[0], + prof_counters); + VLOG(3) << "ParallelForkJoin partition 0 done."; + bc.Wait(); + VLOG(2) << "ParallelForkJoin EXIT"; +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h new file mode 100644 index 0000000000..1ddcaf5274 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Dispatches 'num_partitions' parallel calls to 'function_ptr' and joins +// threads before returning. See comments in runtime_fork_join.cc for details. +extern void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, uint64* prof_counters, tensorflow::int32 num_partitions, + tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, + void* function_ptr); + +} // extern "C" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c614e334a8..cfffb3fbc3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -104,6 +105,7 @@ class JITSymbolTable { ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); + ADD_JIT_SYMBOL_TO_TABLE(ParallelForkJoin); #undef ADD_JIT_SYMBOL_TO_TABLE } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f37a331a72..256ec71ab5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1410,8 +1410,10 @@ xla_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "//third_party/eigen3", ], ) diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 3bf9ccb197..a8f6488996 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -17,8 +17,12 @@ limitations under the License. #include <algorithm> #include <memory> #include <new> +#include <random> #include <utility> +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" @@ -37,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,6 +255,42 @@ XLA_TEST_F(FusionTest, Parameter) { ErrorSpec(1e-4)); } +XLA_TEST_F(FusionTest, RandomizedParallelPartition) { + // Tests parallel partitioning of a fusion instruction. + // Create shape with random outer dimension size to generate random parallel + // partition counts for each test run. + const int seed = tensorflow::testing::RandomSeed(); + LOG(INFO) << "RandomizedParallelPartition seed: " << seed; + std::mt19937 generator(seed); + std::uniform_int_distribution<int> distribution(128, 1024); + const int64 rand_dim0_size = distribution(generator); + const int64 dim1_size = 1024; + Shape shape = + ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); + // Build simple fusion computation: y = x^2 (elementwise). + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); + auto x = + builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); + auto y = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x)); + + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two}, + HloInstruction::FusionKind::kLoop); + // Compute result. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + // Every element of result should be y = x^2 = 4.0. + for (int i = 0; i < rand_dim0_size; ++i) { + for (int j = 0; j < dim1_size; ++j) { + EXPECT_EQ(4.0, result->Get<float>({i, j})); + } + } +} + XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); @@ -722,47 +763,104 @@ void BM_ParallelFusion(int num_iters) { auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); StreamExecutorMemoryAllocator allocator(platform, executors); - const int64 intra_op_parallelism_threads = 16; + const int64 intra_op_parallelism_threads = 24; xla::LocalClientOptions client_options; client_options.set_platform(platform); client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - const int64 dim_size = 1024; - // Create a simple fusable elementwise computation. + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + // Computation shape parameters. + const int64 param0_dim0 = 1024; + const int64 param0_dim1 = 1024; + const int64 param1_dim0 = 1024; + const int64 param1_dim1 = 1024; + const int64 param2_dim0 = 1024; + const int64 param2_dim1 = 1024; + + // Create computation. 
ComputationBuilder builder(client, "ParallelFusion"); - Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); - auto input0 = builder.Broadcast(builder.ConstantR0<float>(1.5f), - AsInt64Slice(input_shape.dimensions())); - auto input1 = builder.Broadcast(builder.ConstantR0<float>(2.0f), - AsInt64Slice(input_shape.dimensions())); - auto input2 = builder.Broadcast(builder.ConstantR0<float>(3.0f), - AsInt64Slice(input_shape.dimensions())); - auto x = builder.Mul(input0, input1); - auto y = builder.Add(x, input2); + Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); + auto param0 = builder.Parameter(0, shape0, "param0"); + Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); + auto param1 = builder.Parameter(1, shape1, "param1"); + Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); + auto param2 = builder.Parameter(2, shape2, "param2"); + + auto x = builder.Mul(param0, param1); + auto y = builder.Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); + // Transfer literals to device. + auto buffer0 = + ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param0_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); + + auto buffer1 = + ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param1_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); + + auto buffer2 = + ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param2_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + + // Build executable. std::unique_ptr<LocalExecutable> executable = - client->Compile(computation, {}, ExecutableBuildOptions()) + client + ->Compile(computation, + {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); - // Run some warm-up executions. + se::Stream stream(executors[client->default_device_ordinal()]); + stream.Init(); + + // Initialize thread pool. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + intra_op_parallelism_threads); + tensorflow::EigenThreadPoolWrapper tp(&pool); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // Initialize ExecutableRunOptions. ExecutableRunOptions options; - options.set_allocator(&allocator); + options.set_allocator(&allocator).set_stream(&stream); + options.set_intra_op_thread_pool(&device); + + // Run some warm-up executions. const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. 
- tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * dim_size * - dim_size * sizeof(float)); + const int64 total_bytes = param0_dim0 * param0_dim0 + + param1_dim0 * param1_dim0 + + param2_dim0 * param2_dim0; + tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) * + total_bytes * sizeof(float)); tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } } |
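For reference, the dispatch pattern in __xla_cpu_runtime_ParallelForkJoin — fork partitions 1..N-1 onto the intra-op pool, run partition 0 inline on the calling thread, then join — can be sketched as a self-contained analogue. This uses std::thread in place of the Eigen intra-op thread pool and BlockingCounter; it illustrates the pattern, not the runtime's actual code:

// Minimal fork/join analogue of runtime_fork_join.cc (illustration only).
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

using PartitionFn = std::function<void(const int64_t* dim_bounds)>;

void ParallelForkJoinSketch(const PartitionFn& fn, const int64_t* partitions,
                            int32_t num_partitions,
                            int32_t num_partitioned_dims) {
  const int64_t stride = 2 * num_partitioned_dims;
  std::vector<std::thread> workers;
  workers.reserve(num_partitions - 1);
  // Fork: dispatch partitions 1..num_partitions-1 to worker threads.
  for (int32_t i = 1; i < num_partitions; ++i) {
    workers.emplace_back(
        [&fn, partitions, stride, i] { fn(&partitions[i * stride]); });
  }
  // Run partition 0 inline, as the runtime does.
  fn(&partitions[0]);
  // Join all workers before returning.
  for (auto& t : workers) t.join();
}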