aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc')
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc68
1 files changed, 20 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 7b227ce294..83d35cb9ef 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -34,24 +34,21 @@ namespace llvm_ir {
ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index,
- llvm::Value* step, bool prevent_unrolling,
- bool prevent_vectorization)
+ llvm::Value* step, bool prevent_unrolling)
: prefix_(prefix.ToString()),
suffix_(suffix.ToString()),
start_index_(start_index),
end_index_(end_index),
step_(step),
insert_before_bb_(nullptr),
- prevent_unrolling_(prevent_unrolling),
- prevent_vectorization_(prevent_vectorization) {}
+ prevent_unrolling_(prevent_unrolling) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling, bool prevent_vectorization) {
- std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
- end_index, step, prevent_unrolling,
- prevent_vectorization));
+ bool prevent_unrolling) {
+ std::unique_ptr<ForLoop> loop(new ForLoop(
+ prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling));
loop->Emit(ir_builder);
return loop;
}
@@ -130,12 +127,14 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->CreateStore(indvar_inc, indvar_address);
llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
- std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder);
- if (!loop_metadata.empty()) {
- llvm::LLVMContext* ctx = &start_index_->getContext();
+ if (prevent_unrolling_) {
+ const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
+ llvm::LLVMContext* ctx = &back_branch->getContext();
+
auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
- loop_metadata.insert(loop_metadata.begin(), temp_node.get());
- auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
+ auto no_unroll_node = llvm::MDNode::get(
+ *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)});
+ auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node});
loop_id->replaceOperandWith(0, loop_id);
back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
}
@@ -144,27 +143,6 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->SetInsertPoint(exit_bb_);
}
-std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
- llvm::IRBuilder<>* ir_builder) {
- const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
- const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
- llvm::LLVMContext* ctx = &start_index_->getContext();
-
- std::vector<llvm::Metadata*> result;
- if (prevent_unrolling_) {
- result.push_back(llvm::MDNode::get(
- *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
- }
-
- if (prevent_vectorization_) {
- result.push_back(llvm::MDNode::get(
- *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
- llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
- }
-
- return result;
-}
-
string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}
@@ -178,25 +156,23 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
- bool prevent_unrolling,
- bool prevent_vectorization) {
+ bool prevent_unrolling) {
return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
- prevent_unrolling, prevent_vectorization);
+ prevent_unrolling);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
llvm::Value* stride,
- bool prevent_unrolling,
- bool prevent_vectorization) {
+ bool prevent_unrolling) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
}
std::unique_ptr<ForLoop> loop(new ForLoop(
/*prefix=*/name_, suffix, start_index, end_index, stride,
- prevent_unrolling, prevent_vectorization));
+ prevent_unrolling));
loop->Emit(ir_builder_);
if (outer_loop_preheader_bb_ == nullptr) {
@@ -215,24 +191,20 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
- bool prevent_vectorization) {
+ bool prevent_unrolling) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
- ir_builder_->getInt64(end_index), prevent_unrolling,
- prevent_vectorization);
+ ir_builder_->getInt64(end_index), prevent_unrolling);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
- bool prevent_vectorization) {
+ bool prevent_unrolling) {
CHECK_LE(start_index, end_index);
return AddLoop(suffix, ir_builder_->getInt64(start_index),
ir_builder_->getInt64(end_index),
- ir_builder_->getInt64(stride), prevent_unrolling,
- prevent_vectorization);
+ ir_builder_->getInt64(stride), prevent_unrolling);
}
IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,