aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc22
1 files changed, 10 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index d420863b85..1b6315ec03 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -39,7 +39,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
- llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
+ llvm::Function* function = b_->GetInsertBlock()->getParent();
CHECK_EQ(io_hlos.size() + 1, function->arg_size());
// An HLO can have duplicated operands. This data structure remembers which
@@ -79,8 +79,8 @@ void HloToIrBindings::EmitBasePointersForHlos(
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
// Emit IR for GetTupleElement instruction and bind to emitted value.
- llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP(
- temp_buffer_base_, ir_builder_->getInt64(offset));
+ llvm::Value* base_ptr =
+ b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset));
BindHloToIrValue(*non_io_hlo,
EmitGetTupleElement(non_io_hlo, base_ptr));
}
@@ -108,15 +108,14 @@ void HloToIrBindings::EmitBasePointersForHlos(
if (slice.allocation()->is_thread_local()) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
- BindHloToIrValue(*non_io_hlo,
- ir_builder_->CreateAlloca(pointee_type), index);
+ BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
+ index);
} else {
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
BindHloToIrValue(
*non_io_hlo,
- ir_builder_->CreateInBoundsGEP(temp_buffer_base_,
- ir_builder_->getInt64(offset)),
+ b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)),
index);
}
});
@@ -129,11 +128,11 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
- GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
+ GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_);
}
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
- EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
+ EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_);
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
@@ -145,11 +144,10 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
llvm::Value* typed_ir_value;
if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
- typed_ir_value = llvm::ConstantExpr::getBitCast(
+ typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
} else {
- typed_ir_value =
- ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo());
+ typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo());
}
ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));