aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-04-10 09:26:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 09:34:47 -0700
commit40f85affa388288a66d5fb9b2295155216e68b94 (patch)
treea33755413bc7e6313cf89b12787b9fe6c4320ffd
parentd7e4458c3ca839fd8f5a86b4342905ce511a47eb (diff)
[XLA:GPU] Add infrastructure for unrolling kernels to improve bandwidth utilization.
We often have simple kernels that do very little actual work, duplicating that can increase the used bandwidth. This change introduces flags and infrastructure for unrolling kernels, it doesn't include any cost heuristics and is disabled by default. Based on code written by Bixia Zheng. PiperOrigin-RevId: 192296781
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc37
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h8
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc48
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.h3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc12
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h7
-rw-r--r--tensorflow/compiler/xla/xla.proto3
14 files changed, 121 insertions, 39 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index f037663e3f..70ae95bf47 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -43,7 +43,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
#ifdef INTEL_MKL
flags->set_xla_cpu_use_mkl_dnn(true);
#endif // INTEL_MKL
-
+ flags->set_xla_gpu_max_kernel_unroll_factor(1);
// Set cudnn batchnorm off by default; it does not provide a performance win
// on average.
flags->set_xla_gpu_use_cudnn_batchnorm(false);
@@ -224,6 +224,11 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
+ "xla_gpu_max_kernel_unroll_factor",
+ int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
+ flag_values->xla_gpu_max_kernel_unroll_factor(),
+ "Specify the maximum kernel unroll factor for the GPU backend."),
+ tensorflow::Flag(
"xla_dump_optimized_hlo_proto_to",
flag_values->mutable_xla_dump_optimized_hlo_proto_to(),
"Dump Hlo after all hlo passes are executed as proto binary into "
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 1e439cde11..54af40506d 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -29,7 +29,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(
: LoopEmitter(target_element_generator, target_array, ir_builder),
dynamic_loop_bounds_(dynamic_loop_bounds) {}
-llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<llvm_ir::IrArray::Index>
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
CHECK(!ShapeUtil::IsTuple(shape_));
CHECK(!ShapeUtil::IsScalar(shape_));
@@ -69,7 +70,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
CHECK(exit_bb_ != nullptr);
- return array_index;
+ return {array_index};
}
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index ce92e36a94..755715634a 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -60,7 +60,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
~ParallelLoopEmitter() override = default;
- llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock(
+ std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index d29cc21ab1..26e497762f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -536,7 +536,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
+
+ int max_unroll_factor = fusion->GetModule()
+ ->config()
+ .debug_options()
+ .xla_gpu_max_kernel_unroll_factor();
+
+ // Find the largest possible power of two to unroll by.
+ // TODO(kramerb): Make this smarter.
+ int unroll_factor = 1;
+ if (!fusion->IsMultiOutputFusion()) {
+ CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
+ int64 num_elements = ShapeUtil::ElementsIn(fusion->shape());
+ for (int i = max_unroll_factor; i > 1; i /= 2) {
+ if (num_elements % i == 0) {
+ unroll_factor = i;
+ break;
+ }
+ }
+ }
+
+ thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
return IrEmitter::HandleFusion(fusion);
}
@@ -2021,7 +2041,7 @@ Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst) {
+ const HloInstruction* inst, int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2113,7 +2133,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst);
+ inst, unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2485,21 +2505,28 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
+ int unroll_factor = thunk->unroll_factor();
VLOG(3) << bindings_.ToString();
const Shape& element_shape = hlo.IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo.shape(), {0})
: hlo.shape();
+ VLOG(3) << "EmitTargetElementLoopInThunk "
+ << ShapeUtil::HumanStringWithLayout(hlo.shape())
+ << " for unroll_factor " << unroll_factor;
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- element_shape, ir_emitter_context_->device_description());
+ element_shape, ir_emitter_context_->device_description(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
- launch_dimensions, &ir_builder_)
+ launch_dimensions, &ir_builder_, unroll_factor)
.EmitLoop(IrName(&hlo));
}
+ CHECK_EQ(unroll_factor, 1)
+ << "multi-output fusion does not support unrolling";
+
// For multiple outputs fusion, we need to emit each operand and the root.
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 66c62e2d2d..b842f480c6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -150,8 +150,10 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
- // Thunk object.
- std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst);
+ // Thunk object. The kernel implementation will be unrolled if unroll_factor
+ // is greater than one.
+ std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst,
+ int unroll_factor = 1);
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index c20a781a33..c24dc1457f 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -30,10 +30,12 @@ namespace gpu {
KernelThunk::KernelThunk(
tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction)
+ const string& kernel_name, const HloInstruction* hlo_instruction,
+ int unroll_factor)
: Thunk(Kind::kKernel, hlo_instruction),
args_(args.begin(), args.end()),
- kernel_name_(kernel_name) {}
+ kernel_name_(kernel_name),
+ unroll_factor_(unroll_factor) {}
tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) {
tensorflow::mutex_lock lock(mutex_);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 9ae455e2fc..df8971b083 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -47,12 +47,14 @@ class KernelThunk : public Thunk {
//
// `hlo_instruction` is as in Thunk. Other arguments are as the class members.
KernelThunk(tensorflow::gtl::ArraySlice<const BufferAllocation*> args,
- const string& kernel_name, const HloInstruction* hlo_instruction);
+ const string& kernel_name, const HloInstruction* hlo_instruction,
+ int unroll_factor);
KernelThunk(const KernelThunk&) = delete;
KernelThunk& operator=(const KernelThunk&) = delete;
~KernelThunk() override = default;
const string& kernel_name() const { return kernel_name_; }
+ int unroll_factor() const { return unroll_factor_; }
void SetLaunchDimensions(const LaunchDimensions& launch_dims);
tensorflow::Status Initialize(const GpuExecutable& executable) override;
@@ -69,6 +71,10 @@ class KernelThunk : public Thunk {
// Entry kernel name for the computation.
const string kernel_name_;
+ // The number of times this kernel should be unrolled. This works as a
+ // multiplier on the number of elements produced by a GPU thread.
+ const int unroll_factor_;
+
// The thread and block dimension used to launch the kernel.
// Will be set by IrEmitterUnnested.
LaunchDimensions launch_dimensions_;
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 388dcc008b..d8c07dc311 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -32,25 +32,32 @@ namespace gpu {
ParallelLoopEmitter::ParallelLoopEmitter(
BodyEmitter body_emitter, const Shape& shape,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(body_emitter, shape, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(target_element_generator, target_arrays, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor)
: LoopEmitter(target_element_generator, target_array, ir_builder),
- launch_dimensions_(launch_dimensions) {}
+ launch_dimensions_(launch_dimensions),
+ unroll_factor_(unroll_factor) {}
-llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<llvm_ir::IrArray::Index>
+ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -63,6 +70,9 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
// "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x"
//
// %nctaid.x is currently specified as 2147483647.
+ VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_;
+ std::vector<llvm_ir::IrArray::Index> array_indices;
+
llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
@@ -81,7 +91,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(),
"thread_id");
- llvm::Value* linear_index = ir_builder_->CreateAdd(
+ llvm::Value* linear_index_base = ir_builder_->CreateAdd(
ir_builder_->CreateMul(
block_id,
ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "",
@@ -99,15 +109,30 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::assume,
{ir_builder_->CreateICmpULT(
- linear_index,
+ linear_index_base,
ir_builder_->getInt64(launch_dimensions_.threads_per_block() *
launch_dimensions_.block_count()),
"linear_index_in_range")},
{}, ir_builder_);
+ if (unroll_factor_ > 1) {
+ linear_index_base = ir_builder_->CreateMul(
+ linear_index_base, ir_builder_->getInt64(unroll_factor_),
+ "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
+ }
+
+ array_indices.emplace_back(linear_index_base, shape_, ir_builder_);
+ for (int i = 1; i < unroll_factor_; ++i) {
+ llvm::Value* linear_index = ir_builder_->CreateAdd(
+ linear_index_base, ir_builder_->getInt64(i), "linear_index",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ array_indices.emplace_back(linear_index, shape_, ir_builder_);
+ }
+
auto if_in_bounds = llvm_ir::EmitIfThenElse(
ir_builder_->CreateICmpULT(
- linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
+ linear_index_base,
+ ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false);
// Set exit_bb_ to the exit block of the if structure.
@@ -116,7 +141,8 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
// Set IR builder insertion point to the body of the if structure.
llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
- return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_);
+
+ return array_indices;
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index 8ed63a854a..25318b3bed 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
// The meanings of other parameters are the same as LoopEmitter.
ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape,
const LaunchDimensions& launch_dimensions,
- llvm::IRBuilder<>* ir_builder);
+ llvm::IRBuilder<>* ir_builder, int unroll_factor = 1);
// Constructs a ParallelLoopEmitter from an element generator that generates
// each element of the given target array.
ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,
const LaunchDimensions& launch_dimensions,
- llvm::IRBuilder<>* ir_builder);
+ llvm::IRBuilder<>* ir_builder, int unroll_factor = 1);
// Constructs a loop emitter for a loop that generates on element of each of N
// arrays on each iteration.
@@ -50,18 +50,20 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
- const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder);
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder,
+ int unroll_factor = 1);
ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
~ParallelLoopEmitter() override = default;
- llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock(
+ std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) override;
private:
// The thread and block dimension to parallelize the loop on.
const LaunchDimensions launch_dimensions_;
+ const int unroll_factor_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index 6cf280df05..5283d51cd1 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -44,12 +44,16 @@ std::ostream& operator<<(std::ostream& out,
// Calculates the launch dimensions used to invoke `hlo`.
LaunchDimensions CalculateLaunchDimensions(
- const Shape& shape, const se::DeviceDescription& device_desc) {
+ const Shape& shape, const se::DeviceDescription& device_desc,
+ int unroll_factor) {
int64 num_elements = ShapeUtil::ElementsIn(shape);
if (num_elements <= 1) {
return LaunchDimensions();
}
+ CHECK_EQ(num_elements % unroll_factor, 0);
+ num_elements = num_elements / unroll_factor;
+
// Since we don't do any inter-warp communication, we're free to choose any
// block size we want, subject to hardware constraints. We choose the
// smallest block size that allows the GPU to reach full occupancy (assuming
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
index 0bf463a6ef..42d2d2af2e 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
@@ -58,7 +58,8 @@ std::ostream& operator<<(std::ostream& out,
LaunchDimensions CalculateLaunchDimensions(
const Shape& shape,
- const perftools::gputools::DeviceDescription& device_desc);
+ const perftools::gputools::DeviceDescription& device_desc,
+ int unroll_factor = 1);
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index b6b918ec78..3978acc132 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -88,12 +88,12 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
}
}
-IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock(
+std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name) {
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
exit_bb_ = nullptr;
- return IrArray::Index();
+ return {IrArray::Index()};
}
// Create loop nest with one for-loop for each dimension of the target shape.
@@ -121,12 +121,14 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock(
exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
CHECK_NOTNULL(exit_bb_);
- return array_index;
+ return {array_index};
}
tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
- IrArray::Index array_index = EmitIndexAndSetExitBasicBlock(loop_name);
- TF_RETURN_IF_ERROR(body_emitter_(array_index));
+ for (const IrArray::Index& array_index :
+ EmitIndexAndSetExitBasicBlock(loop_name)) {
+ TF_RETURN_IF_ERROR(body_emitter_(array_index));
+ }
// Set the insertion point of ir_builder_ to the loop exit, so that
// code emitted for later instructions will be correctly placed.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 0fc528439a..9ff497aecd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -63,11 +63,12 @@ class LoopEmitter {
// Emits a loop nest (with a yet-to-be-filled loop body) that iterates through
// every element in the given shape. Returns the multi-dimensional index that
- // specifies the element.
- IrArray::Index EmitIndexAndSetExitBasicBlock() {
+ // specifies the element, will return multiple indices if the loop is
+ // unrolled.
+ std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock() {
return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"");
}
- virtual IrArray::Index EmitIndexAndSetExitBasicBlock(
+ virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
tensorflow::StringPiece loop_name);
// Emits a complete loop nest for every element in the given shape.
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b4cbdf3773..f619b8dc24 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -192,6 +192,9 @@ message DebugOptions {
// Generate calls to MKL-DNN in the CPU backend.
bool xla_cpu_use_mkl_dnn = 97;
+ // Maximum kernel unroll factor for the GPU backend.
+ int32 xla_gpu_max_kernel_unroll_factor = 98;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;