aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc47
1 files changed, 33 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index d420863b85..8c11cd0541 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -39,7 +41,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
// I/O HLOs are bound to the arguments of the current IR function. I.e.,
//
// void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
- llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
+ llvm::Function* function = b_->GetInsertBlock()->getParent();
CHECK_EQ(io_hlos.size() + 1, function->arg_size());
// An HLO can have duplicated operands. This data structure remembers which
@@ -79,8 +81,8 @@ void HloToIrBindings::EmitBasePointersForHlos(
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
// Emit IR for GetTupleElement instruction and bind to emitted value.
- llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP(
- temp_buffer_base_, ir_builder_->getInt64(offset));
+ llvm::Value* base_ptr =
+ b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset));
BindHloToIrValue(*non_io_hlo,
EmitGetTupleElement(non_io_hlo, base_ptr));
}
@@ -108,15 +110,20 @@ void HloToIrBindings::EmitBasePointersForHlos(
if (slice.allocation()->is_thread_local()) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
- BindHloToIrValue(*non_io_hlo,
- ir_builder_->CreateAlloca(pointee_type), index);
+ BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
+ index);
+ } else if (slice.allocation()->is_constant()) {
+ llvm::Value* global_for_constant =
+ module_->getGlobalVariable(llvm_ir::AsStringRef(
+ llvm_ir::ConstantBufferAllocationToGlobalName(
+ *slice.allocation())));
+ BindHloToIrValue(*non_io_hlo, global_for_constant);
} else {
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
BindHloToIrValue(
*non_io_hlo,
- ir_builder_->CreateInBoundsGEP(temp_buffer_base_,
- ir_builder_->getInt64(offset)),
+ b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)),
index);
}
});
@@ -129,11 +136,19 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
- GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
+ GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_);
}
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
- EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
+ EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_);
+}
+
+// Returns true if `value` has a name that should not be changed.
+static bool HasMeaningfulName(llvm::Value* value) {
+ if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
+ return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
+ }
+ return false;
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
@@ -145,14 +160,18 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
llvm::Value* typed_ir_value;
if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
- typed_ir_value = llvm::ConstantExpr::getBitCast(
+ typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
} else {
- typed_ir_value =
- ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo());
+ typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo());
+ }
+ if (!HasMeaningfulName(ir_value)) {
+ ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
+ }
+ if (!HasMeaningfulName(typed_ir_value)) {
+ typed_ir_value->setName(
+ llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
}
- ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw")));
- typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed")));
return typed_ir_value;
}