 tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 7
 tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc | 5
 tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h | 2
 tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 37
 tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 6
 tensorflow/compiler/xla/service/gpu/kernel_thunk.cc | 6
 tensorflow/compiler/xla/service/gpu/kernel_thunk.h | 8
 tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc | 48
 tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h | 10
 tensorflow/compiler/xla/service/gpu/partition_assignment.cc | 6
 tensorflow/compiler/xla/service/gpu/partition_assignment.h | 3
 tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc | 12
 tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h | 7
 tensorflow/compiler/xla/xla.proto | 3
 14 files changed, 121 insertions(+), 39 deletions(-)
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index f037663e3f..70ae95bf47 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -43,7 +43,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
#ifdef INTEL_MKL
flags->set_xla_cpu_use_mkl_dnn(true);
#endif // INTEL_MKL
-
+ flags->set_xla_gpu_max_kernel_unroll_factor(1);
// Set cudnn batchnorm off by default; it does not provide a performance win
// on average.
flags->set_xla_gpu_use_cudnn_batchnorm(false);
@@ -224,6 +224,11 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
+ "xla_gpu_max_kernel_unroll_factor",
+ int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
+ flag_values->xla_gpu_max_kernel_unroll_factor(),
+ "Specify the maximum kernel unroll factor for the GPU backend."),
+ tensorflow::Flag(
"xla_dump_optimized_hlo_proto_to",
flag_values->mutable_xla_dump_optimized_hlo_proto_to(),
"Dump Hlo after all hlo passes are executed as proto binary into "
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 1e439cde11..54af40506d 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -29,7 +29,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
: LoopEmitter(target_element_generator, target_array, ir_builder),
dynamic_loop_bounds_(dynamic_loop_bounds) {}
-llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<llvm_ir::IrArray::Index>
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
CHECK(!ShapeUtil::IsTuple(shape_));
CHECK(!ShapeUtil::IsScalar(shape_));
@@ -69,7 +70,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
CHECK(exit_bb_ != nullptr);
- return array_index;
+ return {array_index};
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index ce92e36a94..755715634a 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -60,7 +60,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
~ParallelLoopEmitter() override = default;
- llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock(
+ std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index d29cc21ab1..26e497762f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -536,7 +536,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
+
+ int max_unroll_factor = fusion->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_max_kernel_unroll_factor();
+
+ // Find the largest possible power of two to unroll by.
+ // TODO(kramerb): Make this smarter.
+ int unroll_factor = 1;
+ if (!fusion->IsMultiOutputFusion()) {
+ CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
+ int64 num_elements = ShapeUtil::ElementsIn(fusion->shape());
+ for (int i = max_unroll_factor; i > 1; i /= 2) {
+ if (num_elements % i == 0) {
+ unroll_factor = i;
+ break;
+ }
+ }
+ }
+
+ thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
return IrEmitter::HandleFusion(fusion);
}
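Concretely: with the default cap of 1 the search above is a no-op. If the cap is raised to 4 and a loop fusion produces 6 elements, the loop tries 4 (6 % 4 != 0) and then 2 (6 % 2 == 0), so the kernel is emitted with unroll_factor = 2; an odd-sized output, or any multi-output fusion, stays at 1.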
@@ -2021,7 +2041,7 @@ Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst) {
+ const HloInstruction* inst, int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2113,7 +2133,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst);
+ inst, unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2485,21 +2505,28 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
+ int unroll_factor = thunk->unroll_factor();
VLOG(3) << bindings_.ToString();
const Shape& element_shape = hlo.IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo.shape(), {0})
: hlo.shape();
+ VLOG(3) << "EmitTargetElementLoopInThunk "
+ << ShapeUtil::HumanStringWithLayout(hlo.shape())
+ << " for unroll_factor " << unroll_factor;
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- element_shape, ir_emitter_context_->device_description());
+ element_shape, ir_emitter_context_->device_description(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &ir_builder_, unroll_factor)
.EmitLoop(IrName(&hlo));
}
+ CHECK_EQ(unroll_factor, 1)
+ << "multi-output fusion does not support unrolling";
+
// For multiple outputs fusion, we need to emit each operand and the root.
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
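The chosen unroll_factor rides along on the KernelThunk: BuildKernelThunk stores it, EmitTargetElementLoopInThunk reads it back, passes it to CalculateLaunchDimensions so the grid is sized for num_elements / unroll_factor threads, and hands it to ParallelLoopEmitter so each thread emits that many elements. Multi-output fusion keeps a factor of 1, which the CHECK_EQ above enforces.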
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 66c62e2d2d..b842f480c6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -150,8 +150,10 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
- // Thunk object.
- std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst);
+ // Thunk object. The kernel implementation will be unrolled if unroll_factor
+ // is greater than one.
+ std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst,
+ int unroll_factor = 1);
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index c20a781a33..c24dc1457f 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -30,10 +30,12 @@ namespace gpu {
KernelThunk::KernelThunk(
tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction)
+ const string& kernel_name, const HloInstruction* hlo_instruction,
+ int unroll_factor)
: Thunk(Kind::kKernel, hlo_instruction),
args_(args.begin(), args.end()),
- kernel_name_(kernel_name) {}
+ kernel_name_(kernel_name),
+ unroll_factor_(unroll_factor) {}
tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) {
tensorflow::mutex_lock lock(mutex_);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 9ae455e2fc..df8971b083 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -47,12 +47,14 @@ class KernelThunk : public Thunk {
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
KernelThunk(tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction);
+ const string& kernel_name, const HloInstruction* hlo_instruction,
+ int unroll_factor);
KernelThunk(const KernelThunk&) = delete;
KernelThunk& operator=(const KernelThunk&) = delete;
~KernelThunk() override = default;
const string& kernel_name() const { return kernel_name_; }
+ int unroll_factor() const { return unroll_factor_; }
void SetLaunchDimensions(const LaunchDimensions& launch_dims);
tensorflow::Status Initialize(const GpuExecutable& executable) override;
@@ -69,6 +71,10 @@ class KernelThunk : public Thunk {
// Entry kernel name for the computation.
const string kernel_name_;
+ // The number of times this kernel should be unrolled. This works as a
+ // multiplier on the number of elements produced by a GPU thread.
+ const int unroll_factor_;
+
// The thread and block dimension used to launch the kernel.
// Will be set by IrEmitterUnnested.
LaunchDimensions launch_dimensions_;
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 388dcc008b..d8c07dc311 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -32,25 +32,32 @@ namespace gpu {
ParallelLoopEmitter::ParallelLoopEmitter(
BodyEmitter body_emitter, const Shape& shape,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(body_emitter, shape, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(target_element_generator, target_arrays, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(target_element_generator, target_array, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
-llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<llvm_ir::IrArray::Index>
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -63,6 +70,9 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
// "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x"
//
// %nctaid.x is currently specified as 2147483647.
+ VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_;
+ std::vector<llvm_ir::IrArray::Index> array_indices;
+
llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
@@ -81,7 +91,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(),
"thread_id");
- llvm::Value* linear_index = ir_builder_->CreateAdd(
+ llvm::Value* linear_index_base = ir_builder_->CreateAdd(
ir_builder_->CreateMul(
block_id,
ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "",
@@ -99,15 +109,30 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::assume,
{ir_builder_->CreateICmpULT(
- linear_index,
+ linear_index_base,
ir_builder_->getInt64(launch_dimensions_.threads_per_block() *
launch_dimensions_.block_count()),
"linear_index_in_range")},
{}, ir_builder_);
+ if (unroll_factor_ > 1) {
+ linear_index_base = ir_builder_->CreateMul(
+ linear_index_base, ir_builder_->getInt64(unroll_factor_),
+ "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
+ }
+
+ array_indices.emplace_back(linear_index_base, shape_, ir_builder_);
+ for (int i = 1; i < unroll_factor_; ++i) {
+ llvm::Value* linear_index = ir_builder_->CreateAdd(
+ linear_index_base, ir_builder_->getInt64(i), "linear_index",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ array_indices.emplace_back(linear_index, shape_, ir_builder_);
+ }
+
auto if_in_bounds = llvm_ir::EmitIfThenElse(
ir_builder_->CreateICmpULT(
- linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
+ linear_index_base,
+ ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false);
// Set exit_bb_ to the exit block of the if structure.
@@ -116,7 +141,8 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
// Set IR builder insertion point to the body of the if structure.
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
- return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_);
+
+ return array_indices;
}
} // namespace gpu
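In effect, each GPU thread now owns a contiguous run of unroll_factor linear indices rather than a single one, guarded by a single bounds check on the base index (safe because CalculateLaunchDimensions only accepts factors that divide the element count). A scalar C++ paraphrase of the emitted pattern, with hypothetical names — not the emitter itself:

    #include <cstdint>
    #include <vector>

    // Returns the linear indices one "thread" covers under unrolling. Mirrors
    // the IR above: base = (block_id * threads_per_block + thread_id) *
    // unroll_factor, then base, base + 1, ..., base + unroll_factor - 1.
    std::vector<int64_t> UnrolledIndices(int64_t block_id, int64_t thread_id,
                                         int64_t threads_per_block,
                                         int64_t unroll_factor,
                                         int64_t num_elements) {
      std::vector<int64_t> indices;
      const int64_t base =
          (block_id * threads_per_block + thread_id) * unroll_factor;
      if (base >= num_elements) return indices;  // the in_bounds branch
      for (int64_t i = 0; i < unroll_factor; ++i) {
        indices.push_back(base + i);
      }
      return indices;
    }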
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index 8ed63a854a..25318b3bed 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
// The meanings of other parameters are the same as LoopEmitter.
ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape,
const LaunchDimensions& launch_dimensions,
- llvm::IRBuilder<>* ir_builder);
+ llvm::IRBuilder<>* ir_builder, int unroll_factor = 1);
// Constructs a ParallelLoopEmitter from an element generator that generates
// each element of the given target array.
ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
const LaunchDimensions& launch_dimensions,
- llvm::IRBuilder<>* ir_builder);
+ llvm::IRBuilder<>* ir_builder, int unroll_factor = 1);
// Constructs a loop emitter for a loop that generates one element of each of N

// arrays on each iteration.
@@ -50,18 +50,20 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder);
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor = 1);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
~ParallelLoopEmitter() override = default;
- llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock(
+ std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) override;
private:
// The thread and block dimension to parallelize the loop on.
const LaunchDimensions launch_dimensions_;
+ const int unroll_factor_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index 6cf280df05..5283d51cd1 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -44,12 +44,16 @@ std::ostream& operator<<(std::ostream& out,
// Calculates the launch dimensions used to invoke `hlo`.
LaunchDimensions CalculateLaunchDimensions(
- const Shape& shape, const se::DeviceDescription& device_desc) {
+ const Shape& shape, const se::DeviceDescription& device_desc,
+ int unroll_factor) {
int64 num_elements = ShapeUtil::ElementsIn(shape);
if (num_elements <= 1) {
return LaunchDimensions();
}
+ CHECK_EQ(num_elements % unroll_factor, 0);
+ num_elements = num_elements / unroll_factor;
+
// Since we don't do any inter-warp communication, we're free to choose any
// block size we want, subject to hardware constraints. We choose the
// smallest block size that allows the GPU to reach full occupancy (assuming
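With unrolling, the grid is sized for num_elements / unroll_factor logical elements: for example, a 2^20-element output under unroll_factor = 4 launches threads for only 2^18 elements, a 4x reduction in thread count. The CHECK_EQ holds because HandleFusion only selects factors that evenly divide the element count.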
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
index 0bf463a6ef..42d2d2af2e 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
@@ -58,7 +58,8 @@ std::ostream& operator<<(std::ostream& out,
LaunchDimensions CalculateLaunchDimensions(
const Shape& shape,
- const perftools::gputools::DeviceDescription& device_desc);
+ const perftools::gputools::DeviceDescription& device_desc,
+ int unroll_factor = 1);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index b6b918ec78..3978acc132 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -88,12 +88,12 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
}
}
-IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
exit_bb_ = nullptr;
- return IrArray::Index();
+ return {IrArray::Index()};
}
// Create loop nest with one for-loop for each dimension of the target shape.
@@ -121,12 +121,14 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock(
exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
CHECK_NOTNULL(exit_bb_);
- return array_index;
+ return {array_index};
}
tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
- IrArray::Index array_index = EmitIndexAndSetExitBasicBlock(loop_name);
- TF_RETURN_IF_ERROR(body_emitter_(array_index));
+ for (const IrArray::Index& array_index :
+ EmitIndexAndSetExitBasicBlock(loop_name)) {
+ TF_RETURN_IF_ERROR(body_emitter_(array_index));
+ }
// Set the insertion point of ir_builder_ to the loop exit, so that
// code emitted for later instructions will be correctly placed.
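Since the CPU path and non-unrolled GPU kernels return a single-element vector, EmitLoop degenerates to its old behavior; with unrolling, body_emitter_ is simply invoked once per unrolled index inside the same bounds-checked block, so the body is duplicated unroll_factor times in the emitted IR.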
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 0fc528439a..9ff497aecd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -63,11 +63,12 @@ class LoopEmitter {
// Emits a loop nest (with a yet-to-be-filled loop body) that iterates through
// every element in the given shape. Returns the multi-dimensional index that
- // specifies the element.
- IrArray::Index EmitIndexAndSetExitBasicBlock() {
+ // specifies the element; multiple indices are returned if the loop is
+ // unrolled.
+ std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock() {
return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"");
}
- virtual IrArray::Index EmitIndexAndSetExitBasicBlock(
+ virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name);
// Emits a complete loop nest for every element in the given shape.
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b4cbdf3773..f619b8dc24 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -192,6 +192,9 @@ message DebugOptions {
// Generate calls to MKL-DNN in the CPU backend.
bool xla_cpu_use_mkl_dnn = 97;
+ // Maximum kernel unroll factor for the GPU backend.
+ int32 xla_gpu_max_kernel_unroll_factor = 98;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
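The new field defaults to 1 (set in SetDebugOptionsDefaults above), so nothing changes unless a user raises the cap, either via the --xla_gpu_max_kernel_unroll_factor flag registered in AllocateFlags() or by setting the proto field directly. A minimal sketch of the latter, assuming the existing GetDebugOptionsFromFlags() and HloModuleConfig plumbing rather than anything added by this patch:

    #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
    #include "tensorflow/compiler/xla/service/hlo_module_config.h"

    // Allow the GPU backend to unroll elementwise kernels up to 4x for this
    // module. HandleFusion still falls back to a smaller power of two (or 1)
    // when the output element count is not divisible by the cap.
    void AllowUnrolling(xla::HloModuleConfig* config) {
      xla::DebugOptions opts = xla::legacy_flags::GetDebugOptionsFromFlags();
      opts.set_xla_gpu_max_kernel_unroll_factor(4);
      config->set_debug_options(opts);
    }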