aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/vector_support_library.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/vector_support_library.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.h36
1 files changed, 16 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index 49c2a4e2f4..c728f6df0a 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -46,11 +46,11 @@ class VectorSupportLibrary {
// instance (i.e. LoadVector will load a vector of type <`vector_size` x
// `primitive_type`>).
VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
- llvm::IRBuilder<>* ir_builder, std::string name);
+ llvm::IRBuilder<>* b, std::string name);
llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
- return Mul(ir_builder()->getInt64(lhs), rhs);
+ return Mul(b()->getInt64(lhs), rhs);
}
llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
@@ -63,7 +63,7 @@ class VectorSupportLibrary {
llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
- return Add(ir_builder()->getInt64(lhs), rhs);
+ return Add(b()->getInt64(lhs), rhs);
}
llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
@@ -147,13 +147,11 @@ class VectorSupportLibrary {
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
llvm::Value* offset_elements, int64 scale) {
return ComputeOffsetPointer(
- base_pointer,
- ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements));
+ base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
}
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
int64 offset_elements) {
- return ComputeOffsetPointer(base_pointer,
- ir_builder()->getInt64(offset_elements));
+ return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
}
llvm::Value* LoadVector(llvm::Value* pointer);
@@ -164,7 +162,7 @@ class VectorSupportLibrary {
}
llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
- return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
+ return LoadVector(base_pointer, b()->getInt64(offset_elements));
}
llvm::Value* LoadScalar(llvm::Value* pointer);
@@ -175,7 +173,7 @@ class VectorSupportLibrary {
}
llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
- return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
+ return LoadScalar(base_pointer, b()->getInt64(offset_elements));
}
void StoreVector(llvm::Value* value, llvm::Value* pointer);
@@ -187,7 +185,7 @@ class VectorSupportLibrary {
void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
int64 offset_elements) {
- StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
+ StoreVector(value, base_pointer, b()->getInt64(offset_elements));
}
void StoreScalar(llvm::Value* value, llvm::Value* pointer);
@@ -198,7 +196,7 @@ class VectorSupportLibrary {
void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
int64 offset_elements) {
- StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements));
+ StoreScalar(base_pointer, b()->getInt64(offset_elements));
}
llvm::Value* LoadBroadcast(llvm::Value* pointer);
@@ -207,7 +205,7 @@ class VectorSupportLibrary {
return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
}
llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
- return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
+ return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
}
// Compute the horizontal sum of each vector in `vectors`. The i'th element
@@ -220,7 +218,7 @@ class VectorSupportLibrary {
llvm::Value* GetZeroVector();
llvm::Value* GetZeroScalar();
- llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
+ llvm::IRBuilder<>* b() const { return b_; }
int64 vector_size() const { return vector_size_; }
llvm::Type* vector_type() const { return vector_type_; }
llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
@@ -277,7 +275,7 @@ class VectorSupportLibrary {
int64 vector_size_;
PrimitiveType primitive_type_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
llvm::Type* vector_type_;
llvm::Type* vector_pointer_type_;
llvm::Type* scalar_type_;
@@ -289,22 +287,21 @@ class VectorSupportLibrary {
// can later convert to a SSA value.
class LlvmVariable {
public:
- LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);
+ LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);
llvm::Value* Get() const;
void Set(llvm::Value* new_value);
private:
llvm::AllocaInst* alloca_;
- llvm::IRBuilder<>* ir_builder_;
+ llvm::IRBuilder<>* b_;
};
class VectorVariable : public LlvmVariable {
public:
VectorVariable(VectorSupportLibrary* vector_support,
llvm::Value* initial_value)
- : LlvmVariable(vector_support->vector_type(),
- vector_support->ir_builder()) {
+ : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
Set(initial_value);
}
};
@@ -313,8 +310,7 @@ class ScalarVariable : public LlvmVariable {
public:
ScalarVariable(VectorSupportLibrary* vector_support,
llvm::Value* initial_value)
- : LlvmVariable(vector_support->scalar_type(),
- vector_support->ir_builder()) {
+ : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
Set(initial_value);
}
};