about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com> 2018-08-31 17:55:40 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-31 18:00:03 -0700
commit: 1bc856ba29bd57378d5c1ca08afc255460597f7f (patch)
tree: 64d0aa9efd40f8d9b392d5a24fe074f9d415aefc /tensorflow/compiler/xla/service/cpu/ir_emitter.cc
parent: 3e03b4946b05e5ebb5158ec360f120e05c82febd (diff)
[XLA:CPU] Don't use "temps" to refer to the table of buffer allocations
Instead, call it the "buffer table", since it now contains both entry computation parameters and temporaries.

PiperOrigin-RevId: 211171651
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc 41
1 files changed, 20 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 43f2a034ff..e5cf15c686 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -342,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Write the tuple index table.
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(infeed, {0}));
- llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
assignment_.GetUniqueSlice(infeed, {1}));
- llvm::Value* token_address = EmitTempBufferPointer(
+ llvm::Value* token_address = EmitBufferPointer(
token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_,
module_);
@@ -368,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
// instruction. Target addresses for internal elements can be obtained
- // from EmitTempBufferPointer.
+ // from EmitBufferPointer.
llvm::Value* tuple_element_address =
- EmitTempBufferPointer(buffer, tuple_element_shape);
+ EmitBufferPointer(buffer, tuple_element_shape);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(
XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
@@ -1205,7 +1205,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
const Shape& operand_shape = crs->operand(i)->shape();
CHECK(ShapeUtil::IsArray(operand_shape))
<< "Operands to cross-replica-sum must be arrays: " << crs->ToString();
- operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape));
+ operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
// TODO(b/63762267): Be more aggressive about specifying alignment.
MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr,
@@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
{}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument());
HloInstruction* root = computation->root_instruction();
@@ -2622,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() {
return compute_function_->profile_counters_arg();
}
-llvm::Value* IrEmitter::GetTempBuffersArgument() {
- return compute_function_->temp_buffers_arg();
+llvm::Value* IrEmitter::GetBufferTableArgument() {
+ return compute_function_->buffer_table_arg();
}
llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
@@ -2689,11 +2689,11 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+llvm::Value* IrEmitter::EmitGlobalBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
- GetTempBuffersArgument(), slice.index(), &b_);
+ GetBufferTableArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
@@ -2714,14 +2714,14 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
IrShapeType(target_shape)->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
+llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape) {
if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ return EmitThreadLocalBufferPointer(slice, target_shape);
} else if (slice.allocation()->is_constant()) {
return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
} else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
+ return EmitGlobalBufferPointer(slice, target_shape);
}
}
@@ -2729,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
const Shape& target_shape = op->shape();
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ llvm::Value* addr = EmitBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2758,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop(
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(target_op, {i}));
const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
- llvm::Value* op_target_address =
- EmitTempBufferPointer(slice, element_shape);
+ llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
output_arrays.push_back(
llvm_ir::IrArray(op_target_address, element_shape));
}
@@ -2867,7 +2866,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
parameter_addrs, &b_, name,
/*return_value_buffer=*/return_value_buffer,
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
+ /*buffer_table_arg=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
/*profile_counters_arg=*/GetProfileCountersArgument()));
@@ -2884,7 +2883,7 @@ void IrEmitter::EmitGlobalCall(const HloComputation& callee,
/*return_value_buffer=*/
llvm::Constant::getNullValue(b_.getInt8PtrTy()),
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*buffer_table_arg=*/GetBufferTableArgument(),
/*profile_counters_arg=*/GetProfileCountersArgument()));
}
@@ -2897,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
const BufferAllocation::Slice root_buffer =
assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
+ return EmitBufferPointer(root_buffer, root_inst->shape());
}
} // namespace cpu