aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/vector_support_library.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc163
1 files changed, 74 insertions, 89 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index c444d15185..3274be8d9d 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -23,14 +23,14 @@ namespace xla {
namespace cpu {
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
int64 vector_size,
- llvm::IRBuilder<>* ir_builder,
+ llvm::IRBuilder<>* b,
std::string name)
: vector_size_(vector_size),
primitive_type_(primitive_type),
- ir_builder_(ir_builder),
+ b_(b),
name_(std::move(name)) {
scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
- primitive_type, ir_builder_->GetInsertBlock()->getModule());
+ primitive_type, b_->GetInsertBlock()->getModule());
scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
@@ -63,9 +63,9 @@ llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
llvm::Value* rhs) {
if (scalar_type_->isFloatingPointTy()) {
- return ir_builder()->CreateFMul(lhs, rhs, name());
+ return b()->CreateFMul(lhs, rhs, name());
} else {
- return ir_builder()->CreateMul(lhs, rhs, name());
+ return b()->CreateMul(lhs, rhs, name());
}
}
@@ -76,13 +76,13 @@ llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
- return ir_builder()->CreateFSub(lhs, rhs);
+ return b()->CreateFSub(lhs, rhs);
}
llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
if (scalar_type_->isFloatingPointTy()) {
- return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_);
+ return llvm_ir::EmitFloatMax(lhs, rhs, b_);
} else {
LOG(FATAL) << "Max for integers is unimplemented";
}
@@ -91,13 +91,13 @@ llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
AssertCorrectTypes({a});
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
- {a->getType()}, ir_builder());
+ {a->getType()}, b());
}
llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
if (scalar_type_->isFloatingPointTy()) {
- return ir_builder()->CreateFDiv(lhs, rhs, name());
+ return b()->CreateFDiv(lhs, rhs, name());
} else {
LOG(FATAL) << "Division for integers is unimplemented";
}
@@ -111,42 +111,41 @@ llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
CHECK(scalar_type_->isFloatingPointTy());
return llvm_ir::EmitFloatMin(
- llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
- GetConstantFloat(type, high), ir_builder_);
+ llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
+ GetConstantFloat(type, high), b_);
}
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
- return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
+ return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
- return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
+ return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
- return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
+ return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
- return ir_builder()->CreateBitCast(
- ir_builder()->CreateSExt(i1, integer_type, name()),
- is_vector ? vector_type() : scalar_type(), name());
+ return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
+ is_vector ? vector_type() : scalar_type(), name());
}
llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
CHECK(scalar_type()->isFloatingPointTy());
const llvm::DataLayout& data_layout =
- ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
+ b()->GetInsertBlock()->getModule()->getDataLayout();
int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
- llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
+ llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
if (vector) {
return llvm::VectorType::get(scalar_int_type, vector_size());
} else {
@@ -156,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
CHECK_EQ(x->getType(), scalar_type());
- return ir_builder()->CreateVectorSplat(vector_size(), x, name());
+ return b()->CreateVectorSplat(vector_size(), x, name());
}
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
@@ -164,10 +163,9 @@ llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
AssertCorrectTypes({lhs, rhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
- return ir_builder()->CreateBitCast(
- ir_builder()->CreateAnd(
- ir_builder()->CreateBitCast(lhs, int_type, name()),
- ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
+ return b()->CreateBitCast(
+ b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
+ b()->CreateBitCast(rhs, int_type, name()), name()),
vector_type());
}
@@ -175,9 +173,8 @@ llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
AssertCorrectTypes({lhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
- return ir_builder()->CreateBitCast(
- ir_builder()->CreateNot(
- ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
+ return b()->CreateBitCast(
+ b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
vector_type());
}
@@ -185,47 +182,43 @@ llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
- return ir_builder()->CreateBitCast(
- ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
- ir_builder()->CreateBitCast(rhs, int_type, name()),
- name()),
+ return b()->CreateBitCast(
+ b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
+ b()->CreateBitCast(rhs, int_type, name()), name()),
vector_type(), name());
}
llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
llvm::Value* rhs) {
if (scalar_type_->isFloatingPointTy()) {
- return ir_builder()->CreateFAdd(lhs, rhs, name());
+ return b()->CreateFAdd(lhs, rhs, name());
} else {
- return ir_builder()->CreateAdd(lhs, rhs, name());
+ return b()->CreateAdd(lhs, rhs, name());
}
}
llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
llvm::Value* base_pointer, llvm::Value* offset_elements) {
if (base_pointer->getType() != scalar_pointer_type()) {
- base_pointer = ir_builder()->CreateBitCast(base_pointer,
- scalar_pointer_type(), name());
+ base_pointer =
+ b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
}
- return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
- name());
+ return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
}
llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
if (pointer->getType() != vector_pointer_type()) {
- pointer =
- ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
+ pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
}
- return ir_builder()->CreateAlignedLoad(
+ return b()->CreateAlignedLoad(
pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}
llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
if (pointer->getType() != scalar_pointer_type()) {
- pointer =
- ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
- return ir_builder()->CreateAlignedLoad(
+ return b()->CreateAlignedLoad(
pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}
@@ -233,30 +226,28 @@ void VectorSupportLibrary::StoreVector(llvm::Value* value,
llvm::Value* pointer) {
AssertCorrectTypes({value});
if (pointer->getType() != vector_pointer_type()) {
- pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
+ pointer = b()->CreateBitCast(pointer, vector_pointer_type());
}
- ir_builder()->CreateAlignedStore(
- value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+ b()->CreateAlignedStore(value, pointer,
+ ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}
void VectorSupportLibrary::StoreScalar(llvm::Value* value,
llvm::Value* pointer) {
AssertCorrectTypes({value});
if (pointer->getType() != scalar_pointer_type()) {
- pointer =
- ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
- ir_builder()->CreateAlignedStore(
- value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+ b()->CreateAlignedStore(value, pointer,
+ ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}
llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
if (pointer->getType() != scalar_pointer_type()) {
- pointer =
- ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+ pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
}
- return ir_builder()->CreateVectorSplat(
- vector_size(), ir_builder()->CreateLoad(pointer), name());
+ return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
+ name());
}
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
@@ -267,20 +258,19 @@ llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
for (unsigned j = 0; j < vector_size(); ++j) {
if (j < (i / 2)) {
- mask[j] = ir_builder()->getInt32(i / 2 + j);
+ mask[j] = b()->getInt32(i / 2 + j);
} else {
- mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
+ mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
}
}
- llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
- vector, llvm::UndefValue::get(vector_type()),
- llvm::ConstantVector::get(mask), "");
+ llvm::Value* half_remaining_lanes =
+ b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
+ llvm::ConstantVector::get(mask), "");
vector = Add(vector, half_remaining_lanes);
}
- return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
- name());
+ return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}
llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
@@ -307,19 +297,19 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
// vector, which are the lanes 2 and 3 in the rhs vector.
for (int i = 0; i < vector_size(); i += 2) {
int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
- mask_a.push_back(ir_builder()->getInt32(increment + i));
- mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
+ mask_a.push_back(b()->getInt32(increment + i));
+ mask_b.push_back(b()->getInt32(increment + i + 1));
}
for (int i = 0; i < vector_size(); i += 2) {
int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
- mask_a.push_back(ir_builder()->getInt32(increment + i));
- mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
+ mask_a.push_back(b()->getInt32(increment + i));
+ mask_b.push_back(b()->getInt32(increment + i + 1));
}
- llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector(
- lhs, rhs, llvm::ConstantVector::get(mask_a));
- llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector(
- lhs, rhs, llvm::ConstantVector::get(mask_b));
+ llvm::Value* shuffle_0 =
+ b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
+ llvm::Value* shuffle_1 =
+ b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));
return Add(shuffle_0, shuffle_1);
}
@@ -327,23 +317,21 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
llvm::SmallVector<llvm::Constant*, 32> mask;
for (int i = 0; i < vector_size() / 2; i++) {
- mask.push_back(ir_builder()->getInt32(i));
+ mask.push_back(b()->getInt32(i));
}
- return ir_builder()->CreateShuffleVector(vector,
- llvm::UndefValue::get(vector_type()),
- llvm::ConstantVector::get(mask));
+ return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
+ llvm::ConstantVector::get(mask));
}
llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
llvm::SmallVector<llvm::Constant*, 32> mask;
for (int i = 0; i < vector_size() / 2; i++) {
- mask.push_back(ir_builder()->getInt32(i + vector_size() / 2));
+ mask.push_back(b()->getInt32(i + vector_size() / 2));
}
- return ir_builder()->CreateShuffleVector(vector,
- llvm::UndefValue::get(vector_type()),
- llvm::ConstantVector::get(mask));
+ return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
+ llvm::ConstantVector::get(mask));
}
std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
@@ -360,8 +348,8 @@ std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
[this](llvm::Value* vector) { return AddReduce(vector); });
if (init_values) {
for (int64 i = 0, e = result.size(); i < e; i++) {
- result[i] = Add(result[i], ir_builder()->CreateExtractElement(
- init_values, ir_builder()->getInt32(i)));
+ result[i] = Add(result[i],
+ b()->CreateExtractElement(init_values, b()->getInt32(i)));
}
}
return result;
@@ -398,9 +386,9 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
std::vector<llvm::Value*> results;
for (int i = 0; i < lane_width; i++) {
- llvm::Value* scalar_result = ir_builder()->CreateExtractElement(
- i < (lane_width / 2) ? low : high,
- ir_builder()->getInt32(i % (lane_width / 2)), name());
+ llvm::Value* scalar_result =
+ b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
+ b()->getInt32(i % (lane_width / 2)), name());
results.push_back(scalar_result);
}
@@ -415,17 +403,14 @@ llvm::Value* VectorSupportLibrary::GetZeroScalar() {
return llvm::Constant::getNullValue(scalar_type());
}
-LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
- : ir_builder_(ir_builder) {
- alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
+LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
+ alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}
-llvm::Value* LlvmVariable::Get() const {
- return ir_builder_->CreateLoad(alloca_);
-}
+llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }
void LlvmVariable::Set(llvm::Value* new_value) {
- ir_builder_->CreateStore(new_value, alloca_);
+ b_->CreateStore(new_value, alloca_);
}
TileVariable::TileVariable(VectorSupportLibrary* vector_support,