diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc | 68 |
1 files changed, 20 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 7b227ce294..83d35cb9ef 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,24 +34,21 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, - bool prevent_vectorization) + llvm::Value* step, bool prevent_unrolling) : prefix_(prefix.ToString()), suffix_(suffix.ToString()), start_index_(start_index), end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), - prevent_vectorization_(prevent_vectorization) {} + prevent_unrolling_(prevent_unrolling) {} /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { - std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, - prevent_vectorization)); + bool prevent_unrolling) { + std::unique_ptr<ForLoop> loop(new ForLoop( + prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling)); loop->Emit(ir_builder); return loop; } @@ -130,12 +127,14 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->CreateStore(indvar_inc, indvar_address); llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); - std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder); - if (!loop_metadata.empty()) { - llvm::LLVMContext* ctx = &start_index_->getContext(); + if (prevent_unrolling_) { + const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + llvm::LLVMContext* ctx = &back_branch->getContext(); + auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); - loop_metadata.insert(loop_metadata.begin(), temp_node.get()); - auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); + auto no_unroll_node = llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}); + auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node}); loop_id->replaceOperandWith(0, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); } @@ -144,27 +143,6 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(exit_bb_); } -std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata( - llvm::IRBuilder<>* ir_builder) { - const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; - const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; - llvm::LLVMContext* ctx = &start_index_->getContext(); - - std::vector<llvm::Metadata*> result; - if (prevent_unrolling_) { - result.push_back(llvm::MDNode::get( - *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); - } - - if (prevent_vectorization_) { - result.push_back(llvm::MDNode::get( - *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), - llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); - } - - return result; -} - string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } @@ -178,25 +156,23 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, - bool prevent_vectorization) { + bool prevent_unrolling) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + prevent_unrolling); } std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, - bool prevent_vectorization) { + bool prevent_unrolling) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr<ForLoop> loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + prevent_unrolling)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,24 +191,20 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, - bool prevent_vectorization) { + bool prevent_unrolling) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, - prevent_vectorization); + ir_builder_->getInt64(end_index), prevent_unrolling); } std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, - bool prevent_vectorization) { + bool prevent_unrolling) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, - prevent_vectorization); + ir_builder_->getInt64(stride), prevent_unrolling); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, |