diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-31 17:55:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 18:00:03 -0700 |
commit | 1bc856ba29bd57378d5c1ca08afc255460597f7f (patch) | |
tree | 64d0aa9efd40f8d9b392d5a24fe074f9d415aefc /tensorflow/compiler/xla/service/cpu/ir_emitter.cc | |
parent | 3e03b4946b05e5ebb5158ec360f120e05c82febd (diff) |
[XLA:CPU] Don't use "temps" to refer to the table of buffer allocations
Instead call it "buffer table", it now contains both entry computation
parameters and temporaries.
PiperOrigin-RevId: 211171651
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 41 |
1 files changed, 20 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 43f2a034ff..e5cf15c686 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -342,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Write the tuple index table. TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(infeed, {0})); - llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, assignment_.GetUniqueSlice(infeed, {1})); - llvm::Value* token_address = EmitTempBufferPointer( + llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); @@ -368,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // from EmitBufferPointer. llvm::Value* tuple_element_address = - EmitTempBufferPointer(buffer, tuple_element_shape); + EmitBufferPointer(buffer, tuple_element_shape); TF_RETURN_IF_ERROR(EmitXfeedTransfer( XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); @@ -1205,7 +1205,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, @@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument()); HloInstruction* root = computation->root_instruction(); @@ -2622,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } -llvm::Value* IrEmitter::GetTempBuffersArgument() { - return compute_function_->temp_buffers_arg(); +llvm::Value* IrEmitter::GetBufferTableArgument() { + return compute_function_->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { @@ -2689,11 +2689,11 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( +llvm::Value* IrEmitter::EmitGlobalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &b_); + GetBufferTableArgument(), slice.index(), &b_); llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { @@ -2714,14 +2714,14 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { +llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape) { if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); + return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); } else { - return EmitGlobalTempBufferPointer(slice, target_shape); + return EmitGlobalBufferPointer(slice, target_shape); } } @@ -2729,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + llvm::Value* addr = EmitBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2758,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, assignment_.GetUniqueSlice(target_op, {i})); const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); - llvm::Value* op_target_address = - EmitTempBufferPointer(slice, element_shape); + llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape); output_arrays.push_back( llvm_ir::IrArray(op_target_address, element_shape)); } @@ -2867,7 +2866,7 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( parameter_addrs, &b_, name, /*return_value_buffer=*/return_value_buffer, /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ + /*buffer_table_arg=*/ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), /*profile_counters_arg=*/GetProfileCountersArgument())); @@ -2884,7 +2883,7 @@ void IrEmitter::EmitGlobalCall(const HloComputation& callee, /*return_value_buffer=*/ llvm::Constant::getNullValue(b_.getInt8PtrTy()), /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument())); } @@ -2897,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( const BufferAllocation::Slice root_buffer = assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); + return EmitBufferPointer(root_buffer, root_inst->shape()); } } // namespace cpu |