diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index d420863b85..8c11cd0541 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,8 +18,10 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,7 +41,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { - llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); CHECK_EQ(io_hlos.size() + 1, function->arg_size()); // An HLO can have duplicated operands. This data structure remembers which @@ -79,8 +81,8 @@ void HloToIrBindings::EmitBasePointersForHlos( const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); // Emit IR for GetTupleElement instruction and bind to emitted value. - llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP( - temp_buffer_base_, ir_builder_->getInt64(offset)); + llvm::Value* base_ptr = + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)); BindHloToIrValue(*non_io_hlo, EmitGetTupleElement(non_io_hlo, base_ptr)); } @@ -108,15 +110,20 @@ void HloToIrBindings::EmitBasePointersForHlos( if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); - BindHloToIrValue(*non_io_hlo, - ir_builder_->CreateAlloca(pointee_type), index); + BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), + index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); BindHloToIrValue( *non_io_hlo, - ir_builder_->CreateInBoundsGEP(temp_buffer_base_, - ir_builder_->getInt64(offset)), + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)), index); } }); @@ -129,11 +136,19 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_); + EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); +} + +// Returns true if `value` has a name that should not be changed. +static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, @@ -145,14 +160,18 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, llvm::Value* typed_ir_value; if (llvm::isa<llvm::GlobalVariable>(ir_value)) { - typed_ir_value = llvm::ConstantExpr::getBitCast( + typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast<llvm::GlobalVariable>(ir_value), dest_type); } else { - typed_ir_value = - ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + } + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); return typed_ir_value; } |