diff options
author | 2018-08-31 17:55:40 -0700 | |
---|---|---|
committer | 2018-08-31 18:00:03 -0700 | |
commit | 1bc856ba29bd57378d5c1ca08afc255460597f7f (patch) | |
tree | 64d0aa9efd40f8d9b392d5a24fe074f9d415aefc /tensorflow/compiler/xla/service/cpu | |
parent | 3e03b4946b05e5ebb5158ec360f120e05c82febd (diff) |
[XLA:CPU] Don't use "temps" to refer to the table of buffer allocations
Instead call it "buffer table", it now contains both entry computation
parameters and temporaries.
PiperOrigin-RevId: 211171651
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
9 files changed, 79 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 6420180b13..796f36510e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -588,8 +588,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. + // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run(module.get(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 9b00f2eaa5..29abf38e43 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -75,7 +75,7 @@ CpuExecutable::CpuExecutable( StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, std::vector<OwningDeviceMemory>>> -CpuExecutable::CreateTempArray( +CpuExecutable::CreateBufferTable( DeviceMemoryAllocator* memory_allocator, int device_ordinal, absl::Span<const ShapedBuffer* const> arguments) { std::vector<se::DeviceMemoryBase> unowning_buffers( @@ -141,14 +141,14 @@ Status CpuExecutable::ExecuteComputeFunction( // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, - // void** temps_array) + // void** buffer_table) // // result: Points at the result. // run_options: the ExecutableRunOptions object. // args_array: null - // temps_array: An array of pointers, containing pointers to temporary buffers - // required by the executable adn pointers to entry computation - // parameters. 
+ // buffer_table: An array of pointers, containing pointers to temporary + // buffers required by the executable and pointers to entry computation + // parameters. // uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -172,7 +172,7 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << absl::StrFormat( - " func(void* result, void* params[null], void* temps[%u], " + " func(void* result, void* params[null], void* buffer_table[%u], " "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); VLOG(3) << absl::StrFormat(" result = %p", result_buffer); @@ -181,7 +181,8 @@ Status CpuExecutable::ExecuteComputeFunction( }; VLOG(3) << " params = nullptr"; VLOG(3) << absl::StrFormat( - " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + " buffer_table = [%s]", + absl::StrJoin(buffer_pointers, ", ", ptr_printer)); VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } @@ -281,8 +282,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl( std::vector<se::DeviceMemoryBase> unowning_buffers; TF_ASSIGN_OR_RETURN( std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), + arguments)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3571513e02..3c3c047bfe 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -74,9 +74,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. 
- using ComputeFunctionType = void (*)( - void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); + using ComputeFunctionType = + void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*buffer_table*/, + int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -95,15 +96,15 @@ class CpuExecutable : public Executable { absl::Span<const ShapedBuffer* const> arguments, HloExecutionProfile* hlo_execution_profile); - // Creates an array suitable for passing as the "temps" argument to the JIT - // compiled function pointer. + // Creates an array suitable for passing as the "buffer_table" argument to the + // JIT compiled function pointer. // // Returns (unowning_buffers, owning_buffers) where: // - // - unowning_buffers.data() can be passed as the temps argument as-is and - // includes pointers to the scratch storage required by the computation, - // the live-out buffer into which the result will be written and entry - // computation parameters. + // - unowning_buffers.data() can be passed as the buffer_table argument as-is + // and includes pointers to the scratch storage required by the + // computation, the live-out buffer into which the result will be written + // and entry computation parameters. // // - owning_buffers contains owning pointers to the buffers that were // allocated by this routine. This routine allocates buffers for temporary @@ -111,8 +112,8 @@ class CpuExecutable : public Executable { // result. 
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>, std::vector<OwningDeviceMemory>>> - CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - absl::Span<const ShapedBuffer* const> arguments); + CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + absl::Span<const ShapedBuffer* const> arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 43f2a034ff..e5cf15c686 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -342,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Write the tuple index table. TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(infeed, {0})); - llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, assignment_.GetUniqueSlice(infeed, {1})); - llvm::Value* token_address = EmitTempBufferPointer( + llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); @@ -368,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // from EmitBufferPointer. 
llvm::Value* tuple_element_address = - EmitTempBufferPointer(buffer, tuple_element_shape); + EmitBufferPointer(buffer, tuple_element_shape); TF_RETURN_IF_ERROR(EmitXfeedTransfer( XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); @@ -1205,7 +1205,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, @@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument()); HloInstruction* root = computation->root_instruction(); @@ -2622,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } -llvm::Value* IrEmitter::GetTempBuffersArgument() { - return compute_function_->temp_buffers_arg(); +llvm::Value* IrEmitter::GetBufferTableArgument() { + return compute_function_->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { @@ -2689,11 +2689,11 @@ llvm::Value* 
IrEmitter::EmitThreadLocalTempBufferPointer( return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( +llvm::Value* IrEmitter::EmitGlobalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &b_); + GetBufferTableArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { @@ -2714,14 +2714,14 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { +llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape) { if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); + return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); } else { - return EmitGlobalTempBufferPointer(slice, target_shape); + return EmitGlobalBufferPointer(slice, target_shape); } } @@ -2729,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + llvm::Value* addr = EmitBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2758,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, 
assignment_.GetUniqueSlice(target_op, {i})); const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); - llvm::Value* op_target_address = - EmitTempBufferPointer(slice, element_shape); + llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape); output_arrays.push_back( llvm_ir::IrArray(op_target_address, element_shape)); } @@ -2867,7 +2866,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( parameter_addrs, &b_, name, /*return_value_buffer=*/return_value_buffer, /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ + /*buffer_table_arg=*/ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), /*profile_counters_arg=*/GetProfileCountersArgument())); @@ -2884,7 +2883,7 @@ void IrEmitter::EmitGlobalCall(const HloComputation& callee, /*return_value_buffer=*/ llvm::Constant::getNullValue(b_.getInt8PtrTy()), /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument())); } @@ -2897,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( const BufferAllocation::Slice root_buffer = assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); + return EmitBufferPointer(root_buffer, root_inst->shape()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 6efd7fd001..58a333b8fb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -62,8 +62,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. - // assignment: a BufferAssignment from which we know which temporary buffers - // are used by the HLO nodes. 
+ // assignment: a BufferAssignment from which we know which buffers are used by + // the HLO nodes. // llvm_module: the LLVM module to emit IR into. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. @@ -219,24 +219,21 @@ class IrEmitter : public DfsHloVisitorWithDefault, // argument of the computation function being emitted by this emitter. llvm::Value* GetExecutableRunOptionsArgument(); - // Get the llvm::Value* that represents the "temps" argument of the + // Get the llvm::Value* that represents the "buffer_table" argument of the // computation function being emitted by this emitter. - llvm::Value* GetTempBuffersArgument(); + llvm::Value* GetBufferTableArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + // Helper for EmitBufferPointer. + llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitThreadLocalTempBufferPointer( + // Helper for EmitBufferPointer. + llvm::Value* EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape); // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. - llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); // Emits a function into the current module. 
This can be used for // computations embedded inside other computations, such as the @@ -390,8 +387,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Assignment of the temporary buffers needed by the computation and their - // shape information. + // Assignment of the buffers needed by the computation and their shape + // information. const BufferAssignment& assignment_; // The LLVM module into which IR will be emitted. diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 3ecf4b69b7..adfb8392bf 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -78,19 +78,20 @@ void IrFunction::Initialize(const string& function_name, const bool optimize_for_size_requested, const bool enable_fast_math) { // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // void function(i8* retval, i8* run_options, i8** params, i8** + // buffer_table, // i64* dynamic_loop_bounds, i64* prof_counters) // // For thread local functions: // retval: points to the returned value. // params: address of an array with pointers to parameters. - // temps: is null + // buffer_table: is null // // For global functions: // retval: is null // params: is null - // temps: address of an array with pointers to temporary buffers and entry - // computation parameters. + // buffer_table: address of an array with pointers to temporary buffers and + // entry computation parameters (but not to constant buffers). 
// // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -116,7 +117,7 @@ void IrFunction::Initialize(const string& function_name, // \---------/ \---------/ \-----------/ // // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // buffer_table---> | buff 0 | buff 1 | ..... | buff N-1 | // | addr | addr | | addr | // \---------------------------------------------/ // | | | @@ -134,9 +135,9 @@ void IrFunction::Initialize(const string& function_name, // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | // \---------------------------------------------/ - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. + // Even though the type of params and buffer_table is void** in the host's + // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to + // the code to use GEPs to unravel the indirection layers. 
llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), /*Params=*/ @@ -160,8 +161,8 @@ void IrFunction::Initialize(const string& function_name, exec_run_options_arg_ = &*arg_iter; (++arg_iter)->setName("params"); parameters_arg_ = &*arg_iter; - (++arg_iter)->setName("temps"); - temp_buffers_arg_ = &*arg_iter; + (++arg_iter)->setName("buffer_table"); + buffer_table_arg_ = &*arg_iter; if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); dynamic_loop_bounds_arg_ = &*arg_iter; @@ -202,7 +203,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { std::vector<llvm::Value*> GetArrayFunctionCallArguments( absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, - llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; @@ -230,7 +231,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( }; std::vector<llvm::Value*> arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), - parameter_addresses_buffer, temp_buffers_arg}; + parameter_addresses_buffer, buffer_table_arg}; if (profile_counters_arg != nullptr) { arguments.push_back(profile_counters_arg); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 28c69c85a9..623a5f185f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -80,8 +80,9 @@ class IrFunction { // Get the llvm::Value* that represents this functions parameters argument. llvm::Value* parameters_arg() { return parameters_arg_; } - // Get the llvm::Value* that represents this functions "temps" argument. 
- llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + // Get the llvm::Value* that represents this functions "buffer_table" + // argument. + llvm::Value* buffer_table_arg() { return buffer_table_arg_; } // Get the llvm::Value* that represents this functions "prof_counters" // argument. @@ -108,7 +109,7 @@ class IrFunction { llvm::Argument* result_arg_; llvm::Value* exec_run_options_arg_; llvm::Value* parameters_arg_; - llvm::Value* temp_buffers_arg_; + llvm::Value* buffer_table_arg_; llvm::Value* dynamic_loop_bounds_arg_ = nullptr; llvm::Value* profile_counters_arg_; }; @@ -117,7 +118,7 @@ class IrFunction { std::vector<llvm::Value*> GetArrayFunctionCallArguments( absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, - llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, llvm::Value* profile_counters_arg); // Emits a call to a runtime fork/join function which dispatches parallel diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index a5f34908d7..2d9492eacf 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, uint64* prof_counters, int32 num_partitions, + void** buffer_table, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions @@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; 
++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, nullptr, temps, + function(result_ptr, run_options_ptr, nullptr, buffer_table, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; @@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( } // Call first compute function inline. - function(result_ptr, run_options_ptr, params, temps, &partitions[0], + function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0], prof_counters); VLOG(3) << "ParallelForkJoin partition 0 done."; bc.Wait(); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index 1cf0ec6e3d..a279c7d2d6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -24,7 +24,7 @@ extern "C" { // threads before returning. See comments in runtime_fork_join.cc for details. extern void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, tensorflow::uint64* prof_counters, + void** buffer_table, tensorflow::uint64* prof_counters, tensorflow::int32 num_partitions, tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, void* function_ptr); |