diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc | 49 |
1 files changed, 22 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 5fc08aab91..11ed6ee59f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -31,12 +31,12 @@ namespace llvm_ir { void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( + b->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = b->CreateICmpNE( pred_value, llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0), "boolean_predicate"); @@ -46,47 +46,42 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - llvm::Value* const element_index[] = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {b->getInt64(0), b->getInt64(i)}; llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = b->CreateLoad( on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = b->CreateLoad( on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, - "select_output_element_" + llvm::Twine(i)), - output_element_address); + b->CreateInBoundsGEP(select.GetBasePointer(), element_index); + b->CreateStore(b->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), + output_element_address); } } void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice<llvm::Value*> operands, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { - auto* store = ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, module)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + auto* store = b->CreateStore( + b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), + b->CreateInBoundsGEP(tuple.GetBasePointer(), + {b->getInt64(0), b->getInt64(i)})); tuple.AnnotateLoadStoreInstructionWithMetadata(store); } } llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::Value* element_ptr = + b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)}); + llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. if (!ShapeUtil::IsOpaque(target_shape)) { @@ -98,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::Type* element_type = ShapeToIrType(target_shape, module); llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + b->CreateBitCast(src_buffer, element_type->getPointerTo()); return ret_val; } |