aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc')
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc49
1 files changed, 22 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 5fc08aab91..11ed6ee59f 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -31,12 +31,12 @@ namespace llvm_ir {
void EmitTupleSelect(const IrArray& select, const IrArray& pred,
llvm::Value* on_true, llvm::Value* on_false,
- llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
+ llvm::IRBuilder<>* b, llvm::Module* module) {
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
llvm::LoadInst* pred_value =
- ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
- llvm::Value* pred_cond = ir_builder->CreateICmpNE(
+ b->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = b->CreateICmpNE(
pred_value,
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0),
"boolean_predicate");
@@ -46,47 +46,42 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
VLOG(2) << " pred_cond: " << DumpToString(*pred_cond);
for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) {
- llvm::Value* const element_index[] = {ir_builder->getInt64(0),
- ir_builder->getInt64(i)};
+ llvm::Value* const element_index[] = {b->getInt64(0), b->getInt64(i)};
llvm::Value* on_true_element_address =
- ir_builder->CreateInBoundsGEP(on_true, element_index);
- llvm::Value* on_true_element = ir_builder->CreateLoad(
+ b->CreateInBoundsGEP(on_true, element_index);
+ llvm::Value* on_true_element = b->CreateLoad(
on_true_element_address, "on_true_element_" + llvm::Twine(i));
llvm::Value* on_false_element_address =
- ir_builder->CreateInBoundsGEP(on_false, element_index);
- llvm::Value* on_false_element = ir_builder->CreateLoad(
+ b->CreateInBoundsGEP(on_false, element_index);
+ llvm::Value* on_false_element = b->CreateLoad(
on_false_element_address, "on_false_element_" + llvm::Twine(i));
llvm::Value* output_element_address =
- ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index);
- ir_builder->CreateStore(
- ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element,
- "select_output_element_" + llvm::Twine(i)),
- output_element_address);
+ b->CreateInBoundsGEP(select.GetBasePointer(), element_index);
+ b->CreateStore(b->CreateSelect(pred_cond, on_true_element, on_false_element,
+ "select_output_element_" + llvm::Twine(i)),
+ output_element_address);
}
}
void EmitTuple(const IrArray& tuple,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
- llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
+ llvm::IRBuilder<>* b, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
- auto* store = ir_builder->CreateStore(
- ir_builder->CreatePointerCast(operands[i],
- PrimitiveTypeToIrType(TUPLE, module)),
- ir_builder->CreateInBoundsGEP(
- tuple.GetBasePointer(),
- {ir_builder->getInt64(0), ir_builder->getInt64(i)}));
+ auto* store = b->CreateStore(
+ b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)),
+ b->CreateInBoundsGEP(tuple.GetBasePointer(),
+ {b->getInt64(0), b->getInt64(i)}));
tuple.AnnotateLoadStoreInstructionWithMetadata(store);
}
}
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
- llvm::IRBuilder<>* ir_builder,
- llvm::Module* module) {
- llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
- operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
- llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
+ llvm::IRBuilder<>* b, llvm::Module* module) {
+ llvm::Value* element_ptr =
+ b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)});
+ llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr);
// Mark the loaded pointer as dereferenceable if we know its shape.
if (!ShapeUtil::IsOpaque(target_shape)) {
@@ -98,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
llvm::Type* element_type = ShapeToIrType(target_shape, module);
llvm::Value* ret_val =
- ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
+ b->CreateBitCast(src_buffer, element_type->getPointerTo());
return ret_val;
}