aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc39
1 files changed, 27 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index d909845a3a..72ede377e1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -52,7 +52,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
// that would be regenerated without caching. But this might increase the
// JIT compilation time.
if (generated_value_bb == nullptr ||
- generated_value_bb == ir_builder_->GetInsertBlock()) {
+ generated_value_bb == b_->GetInsertBlock()) {
VLOG(3) << "The cached generated value is reused.";
return generated_value;
}
@@ -60,8 +60,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
"a different BB ("
<< llvm_ir::AsString(generated_value_bb->getName())
<< ") from the current insertion block ("
- << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName())
- << ").";
+ << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ").";
}
TF_ASSIGN_OR_RETURN(
@@ -77,14 +76,14 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
llvm::GlobalVariable* global = new llvm::GlobalVariable(
- *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
+ *b_->GetInsertBlock()->getModule(), initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
/*Name=*/"");
llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
generators_[constant] = [=](const IrArray::Index& index) {
return IrArray(shape_constant, constant->shape())
- .EmitReadArrayElement(index, ir_builder_);
+ .EmitReadArrayElement(index, b_);
};
return Status::OK();
@@ -104,7 +103,7 @@ Status FusedIrEmitter::HandleGetTupleElement(
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
get_tuple_element->shape(), get_tuple_element->tuple_index(),
- /*alignment=*/1, it->second, ir_builder_, module_);
+ /*alignment=*/1, it->second, b_, module_);
gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
// Emit code to read base tuple element array (if non-tuple shaped).
if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
@@ -112,16 +111,32 @@ Status FusedIrEmitter::HandleGetTupleElement(
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
// TODO(b/34080002) Add aliasing information to tuple element IrArray.
return IrArray(tuple_element_ptr, get_tuple_element->shape())
- .EmitReadArrayElement(index, ir_builder_);
+ .EmitReadArrayElement(index, b_);
};
}
return Status::OK();
}
Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
- generators_[parameter] = [=](const IrArray::Index& index) {
+ generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* {
+ if (tiled_parameter_info_) {
+ if (llvm::Value* param_tile_buffer =
+ tiled_parameter_info_->GetBufferForParameter(
+ parameter->parameter_number())) {
+ // TODO(jlebar): Add AA metadata to this load. Tile buffers are global
+ // variables, so LLVM's points-to analysis doesn't help us much. And we
+ // want the AA info to be present before address spaces are inferred
+ // (which is pretty late in the pipeline), so even if we had
+ // address-space-based AA in LLVM, it wouldn't help us much here.
+ return b_->CreateLoad(
+ b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
+ tiled_parameter_info_->x(),
+ tiled_parameter_info_->y()}),
+ "tiled_buffer");
+ }
+ }
return parameter_arrays_[parameter->parameter_number()]
- .EmitReadArrayElement(index, ir_builder_);
+ .EmitReadArrayElement(index, b_);
};
// Store ir value for fusion operand associated with fusion parameter to be
// accessed by subsequent fused GetTupleElement instructions.
@@ -140,11 +155,11 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
}
generators_[tuple] =
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get(
- ir_builder_->getContext(), operand_elemental_ir_types));
+ llvm::Value* ret = llvm::UndefValue::get(
+ llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index));
- ret = ir_builder_->CreateInsertValue(ret, val_i, i);
+ ret = b_->CreateInsertValue(ret, val_i, i);
}
return ret;
};