aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-31 17:55:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 18:00:03 -0700
commit1bc856ba29bd57378d5c1ca08afc255460597f7f (patch)
tree64d0aa9efd40f8d9b392d5a24fe074f9d415aefc /tensorflow/compiler/xla/service/cpu
parent3e03b4946b05e5ebb5158ec360f120e05c82febd (diff)
[XLA:CPU] Don't use "temps" to refer to the table of buffer allocations
Instead call it "buffer table", it now contains both entry computation parameters and temporaries. PiperOrigin-RevId: 211171651
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc19
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h23
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc41
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h29
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc25
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.h9
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.h2
9 files changed, 79 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 6420180b13..796f36510e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -588,8 +588,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
DFSMemoryScheduler));
- // Run buffer analysis on the HLO graph. This analysis figures out which
- // temporary buffers are required to run the computation.
+ // Run buffer allocation on the HLO graph.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(module.get(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 9b00f2eaa5..29abf38e43 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -75,7 +75,7 @@ CpuExecutable::CpuExecutable(
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
-CpuExecutable::CreateTempArray(
+CpuExecutable::CreateBufferTable(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
absl::Span<const ShapedBuffer* const> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
@@ -141,14 +141,14 @@ Status CpuExecutable::ExecuteComputeFunction(
// The calling convention for JITed functions is:
//
// void function(void* result, const void* run_options, void** args_array,
- // void** temps_array)
+ // void** buffer_table)
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
// args_array: null
- // temps_array: An array of pointers, containing pointers to temporary buffers
- // required by the executable adn pointers to entry computation
- // parameters.
+ // buffer_table: An array of pointers, containing pointers to temporary
buffers required by the executable and pointers to entry computation
+ // parameters.
//
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -172,7 +172,7 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << absl::StrFormat(
- " func(void* result, void* params[null], void* temps[%u], "
+ " func(void* result, void* params[null], void* buffer_table[%u], "
"uint64 profile_counters[%u])",
buffer_pointers.size(), profile_counters_size);
VLOG(3) << absl::StrFormat(" result = %p", result_buffer);
@@ -181,7 +181,8 @@ Status CpuExecutable::ExecuteComputeFunction(
};
VLOG(3) << " params = nullptr";
VLOG(3) << absl::StrFormat(
- " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer));
+ " buffer_table = [%s]",
+ absl::StrJoin(buffer_pointers, ", ", ptr_printer));
VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters);
}
@@ -281,8 +282,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
+ CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 3571513e02..3c3c047bfe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -74,9 +74,10 @@ class CpuExecutable : public Executable {
static int64 ShapeSizeBytes(const Shape& shape);
// Type of the computation function we expect in the JIT.
- using ComputeFunctionType = void (*)(
- void* /*result*/, const ExecutableRunOptions* /*run_options*/,
- const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/);
+ using ComputeFunctionType =
+ void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
+ const void** /*args*/, void** /*buffer_table*/,
+ int64* /*profile_counters*/);
const ComputeFunctionType& compute_function() const {
return compute_function_;
@@ -95,15 +96,15 @@ class CpuExecutable : public Executable {
absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile);
- // Creates an array suitable for passing as the "temps" argument to the JIT
- // compiled function pointer.
+ // Creates an array suitable for passing as the "buffer_table" argument to the
+ // JIT compiled function pointer.
//
// Returns (unowning_buffers, owning_buffers) where:
//
- // - unowning_buffers.data() can be passed as the temps argument as-is and
- // includes pointers to the scratch storage required by the computation,
- // the live-out buffer into which the result will be written and entry
- // computation parameters.
+ // - unowning_buffers.data() can be passed as the buffer_table argument as-is
+ // and includes pointers to the scratch storage required by the
+ // computation, the live-out buffer into which the result will be written
+ // and entry computation parameters.
//
// - owning_buffers contains owning pointers to the buffers that were
// allocated by this routine. This routine allocates buffers for temporary
@@ -111,8 +112,8 @@ class CpuExecutable : public Executable {
// result.
StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
std::vector<OwningDeviceMemory>>>
- CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- absl::Span<const ShapedBuffer* const> arguments);
+ CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ absl::Span<const ShapedBuffer* const> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 43f2a034ff..e5cf15c686 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -342,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Write the tuple index table.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(infeed, {0}));
- llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
assignment_.GetUniqueSlice(infeed, {1}));
- llvm::Value* token_address = EmitTempBufferPointer(
+ llvm::Value* token_address = EmitBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
module_);
@@ -368,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
// instruction. Target addresses for internal elements can be obtained
- // from EmitTempBufferPointer.
+ // from EmitBufferPointer.
llvm::Value* tuple_element_address =
- EmitTempBufferPointer(buffer, tuple_element_shape);
+ EmitBufferPointer(buffer, tuple_element_shape);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(
XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
@@ -1205,7 +1205,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
const Shape& operand_shape = crs->operand(i)->shape();
CHECK(ShapeUtil::IsArray(operand_shape))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+ operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
@@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
{}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument());
HloInstruction* root = computation->root_instruction();
@@ -2622,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() {
return compute_function_->profile_counters_arg();
}
-llvm::Value* IrEmitter::GetTempBuffersArgument() {
- return compute_function_->temp_buffers_arg();
+llvm::Value* IrEmitter::GetBufferTableArgument() {
+ return compute_function_->buffer_table_arg();
}
llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
@@ -2689,11 +2689,11 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+llvm::Value* IrEmitter::EmitGlobalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
- GetTempBuffersArgument(), slice.index(), &b_);
+ GetBufferTableArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
@@ -2714,14 +2714,14 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
+llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape) {
if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ return EmitThreadLocalBufferPointer(slice, target_shape);
} else if (slice.allocation()->is_constant()) {
return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
} else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
+ return EmitGlobalBufferPointer(slice, target_shape);
}
}
@@ -2729,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
const Shape& target_shape = op->shape();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ llvm::Value* addr = EmitBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2758,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop(
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(target_op, {i}));
const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
- llvm::Value* op_target_address =
- EmitTempBufferPointer(slice, element_shape);
+ llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
output_arrays.push_back(
llvm_ir::IrArray(op_target_address, element_shape));
}
@@ -2867,7 +2866,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
parameter_addrs, &b_, name,
/*return_value_buffer=*/return_value_buffer,
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
+ /*buffer_table_arg=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
/*profile_counters_arg=*/GetProfileCountersArgument()));
@@ -2884,7 +2883,7 @@ void IrEmitter::EmitGlobalCall(const HloComputation& callee,
/*return_value_buffer=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()),
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument()));
}
@@ -2897,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
const BufferAllocation::Slice root_buffer =
assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
+ return EmitBufferPointer(root_buffer, root_inst->shape());
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 6efd7fd001..58a333b8fb 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -62,8 +62,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Create a new LLVM IR emitter.
//
// hlo_module: the HLO module we are emitting IR for.
- // assignment: a BufferAssignment from which we know which temporary buffers
- // are used by the HLO nodes.
+ // assignment: a BufferAssignment from which we know which buffers are used by
+ // the HLO nodes.
// llvm_module: the LLVM module to emit IR into.
// instruction_to_profile_idx: the mapping from HLO instructions to their
// index in the profiling array.
@@ -219,24 +219,21 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// argument of the computation function being emitted by this emitter.
llvm::Value* GetExecutableRunOptionsArgument();
- // Get the llvm::Value* that represents the "temps" argument of the
+ // Get the llvm::Value* that represents the "buffer_table" argument of the
// computation function being emitted by this emitter.
- llvm::Value* GetTempBuffersArgument();
+ llvm::Value* GetBufferTableArgument();
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitThreadLocalTempBufferPointer(
+ // Helper for EmitBufferPointer.
+ llvm::Value* EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape);
// Emits code that computes the address of the given buffer allocation slice.
- //
- // TODO(sanjoy): This should be renamed to reflect that it no longer provides
- // access to just temporaries.
- llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
+ llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
// Emits a function into the current module. This can be used for
// computations embedded inside other computations, such as the
@@ -390,8 +387,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& source_array);
- // Assignment of the temporary buffers needed by the computation and their
- // shape information.
+ // Assignment of the buffers needed by the computation and their shape
+ // information.
const BufferAssignment& assignment_;
// The LLVM module into which IR will be emitted.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 3ecf4b69b7..adfb8392bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -78,19 +78,20 @@ void IrFunction::Initialize(const string& function_name,
const bool optimize_for_size_requested,
const bool enable_fast_math) {
// The function signature is:
- // void function(i8* retval, i8* run_options, i8** params, i8** temps,
+ // void function(i8* retval, i8* run_options, i8** params, i8**
+ // buffer_table,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
// For thread local functions:
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
- // temps: is null
+ // buffer_table: is null
//
// For global functions:
// retval: is null
// params: is null
- // temps: address of an array with pointers to temporary buffers and entry
- // computation parameters.
+ // buffer_table: address of an array with pointers to temporary buffers and
+ // entry computation parameters (but not to constant buffers).
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -116,7 +117,7 @@ void IrFunction::Initialize(const string& function_name,
// \---------/ \---------/ \-----------/
//
// /---------------------------------------------\
- // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
// buffer_table---> | buff 0 | buff 1 | ..... | buff N-1 |
// | addr | addr | | addr |
// \---------------------------------------------/
// | | |
@@ -134,9 +135,9 @@ void IrFunction::Initialize(const string& function_name,
// prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
// \---------------------------------------------/
- // Even though the type of params and temps is void** in the host's view, in
- // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
- // to use GEPs to unravel the indirection layers.
+ // Even though the type of params and buffer_table is void** in the host's
+ // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
+ // the code to use GEPs to unravel the indirection layers.
llvm::FunctionType* function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
/*Params=*/
@@ -160,8 +161,8 @@ void IrFunction::Initialize(const string& function_name,
exec_run_options_arg_ = &*arg_iter;
(++arg_iter)->setName("params");
parameters_arg_ = &*arg_iter;
- (++arg_iter)->setName("temps");
- temp_buffers_arg_ = &*arg_iter;
+ (++arg_iter)->setName("buffer_table");
+ buffer_table_arg_ = &*arg_iter;
if (num_dynamic_loop_bounds_ > 0) {
(++arg_iter)->setName("dynamic_loop_bounds");
dynamic_loop_bounds_arg_ = &*arg_iter;
@@ -202,7 +203,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
absl::string_view name, llvm::Value* return_value_buffer,
- llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
llvm::Value* profile_counters_arg) {
llvm::Value* parameter_addresses_buffer;
@@ -230,7 +231,7 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
};
std::vector<llvm::Value*> arguments{
to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
- parameter_addresses_buffer, temp_buffers_arg};
+ parameter_addresses_buffer, buffer_table_arg};
if (profile_counters_arg != nullptr) {
arguments.push_back(profile_counters_arg);
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index 28c69c85a9..623a5f185f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -80,8 +80,9 @@ class IrFunction {
// Get the llvm::Value* that represents this functions parameters argument.
llvm::Value* parameters_arg() { return parameters_arg_; }
- // Get the llvm::Value* that represents this functions "temps" argument.
- llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; }
+ // Get the llvm::Value* that represents this functions "buffer_table"
+ // argument.
+ llvm::Value* buffer_table_arg() { return buffer_table_arg_; }
// Get the llvm::Value* that represents this functions "prof_counters"
// argument.
@@ -108,7 +109,7 @@ class IrFunction {
llvm::Argument* result_arg_;
llvm::Value* exec_run_options_arg_;
llvm::Value* parameters_arg_;
- llvm::Value* temp_buffers_arg_;
+ llvm::Value* buffer_table_arg_;
llvm::Value* dynamic_loop_bounds_arg_ = nullptr;
llvm::Value* profile_counters_arg_;
};
@@ -117,7 +118,7 @@ class IrFunction {
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
absl::string_view name, llvm::Value* return_value_buffer,
- llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
+ llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
llvm::Value* profile_counters_arg);
// Emits a call to a runtime fork/join function which dispatches parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index a5f34908d7..2d9492eacf 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
//
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, uint64* prof_counters, int32 num_partitions,
+ void** buffer_table, uint64* prof_counters, int32 num_partitions,
int64* partitions, int32 num_partitioned_dims, void* function_ptr) {
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
@@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, nullptr, temps,
+ function(result_ptr, run_options_ptr, nullptr, buffer_table,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
@@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
}
// Call first compute function inline.
- function(result_ptr, run_options_ptr, params, temps, &partitions[0],
+ function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0],
prof_counters);
VLOG(3) << "ParallelForkJoin partition 0 done.";
bc.Wait();
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
index 1cf0ec6e3d..a279c7d2d6 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
@@ -24,7 +24,7 @@ extern "C" {
// threads before returning. See comments in runtime_fork_join.cc for details.
extern void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
- void** temps, tensorflow::uint64* prof_counters,
+ void** buffer_table, tensorflow::uint64* prof_counters,
tensorflow::int32 num_partitions, tensorflow::int64* partitions,
tensorflow::int32 num_partitioned_dims, void* function_ptr);