aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-07-04 03:30:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-04 03:33:12 -0700
commit51061e1eae6fdbe1f7af78e1c1952cbce318099a (patch)
tree66b0f4f1b48e92e1de665502e4e3a9d21542af03 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parentcead016fe01e8f0ff732a080203b29990d475c98 (diff)
Profile SequentialThunks if they represent one HloInstruction.
SequentialThunks are used in two different ways: sometimes as a sequence of individual thunks for different HloInstructions, and sometimes for one HloInstruction which consists of several thunks. For the latter, we want to measure the total time taken by the HloInstruction. Previously, we would instead measure the time of the last thunk from the SequentialThunk. PiperOrigin-RevId: 203258617
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc67
1 files changed, 45 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1d0b6597eb..945ad371e8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -355,7 +355,8 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
unroll_factor = ComputeMaxUnrollFactor(hlo);
}
- thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor));
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ hlo, /*implements_whole_instruction=*/true, unroll_factor));
return IrEmitter::DefaultAction(hlo);
}
@@ -369,7 +370,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
return IrEmitter::HandleDot(dot);
}
@@ -379,7 +381,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
- thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
return IrEmitter::HandleConvolution(convolution);
}
@@ -586,7 +589,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
}
CHECK(first_reduce != nullptr);
- thunks.push_back(BuildKernelThunk(fusion));
+ thunks.push_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
std::vector<llvm_ir::IrArray> parameter_arrays;
@@ -660,7 +664,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// touching the un-updated elements.
// Set up kernel thunk and fused ir emitter.
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/true));
std::vector<llvm_ir::IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand, *fusion));
@@ -695,7 +700,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
int unroll_factor = ComputeMaxUnrollFactor(fusion);
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ fusion, /*implements_whole_instruction=*/true, unroll_factor));
return IrEmitter::HandleFusion(fusion);
}
@@ -1013,7 +1019,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (is_transpose_021 &&
reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
- thunk_sequence_->emplace_back(BuildKernelThunk(copy));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(copy, /*implements_whole_instruction=*/true));
VLOG(3) << "Emitting tiled 0-2-1 transposition";
constexpr int64 tile_size = 32;
constexpr int64 num_rows = 8;
@@ -1982,7 +1989,8 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
BuildInitializerThunk(reduce));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(reduce));
+ thunks.push_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
@@ -1998,7 +2006,8 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
dimensions_to_reduce, {reducer}, {{}}, {});
}
- thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/true));
return IrEmitter::HandleReduce(reduce);
}
@@ -2027,7 +2036,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTuple(tuple);
}
@@ -2052,7 +2062,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
BuildInitializerThunk(select_and_scatter));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(select_and_scatter));
+ thunks.push_back(BuildKernelThunk(select_and_scatter,
+ /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
@@ -2260,17 +2271,20 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
}
Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
- thunk_sequence_->push_back(BuildKernelThunk(random));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(random, /*implements_whole_instruction=*/true));
return IrEmitter::HandleRng(random);
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
- thunk_sequence_->push_back(BuildKernelThunk(select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleSelect(select);
}
Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
- thunk_sequence_->push_back(BuildKernelThunk(tuple_select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTupleSelect(tuple_select);
}
@@ -2309,12 +2323,12 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
- /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), crs));
+ GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
MakeUnique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
@@ -2448,7 +2462,8 @@ GetHloBufferSlices(const HloInstruction* hlo,
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst, int unroll_factor) {
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2540,7 +2555,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst, unroll_factor);
+ implements_whole_instruction ? inst : nullptr,
+ unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2697,6 +2713,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
init_value = hlo->operand(init_value->parameter_number());
}
+ // Initializer thunks don't implement a whole instruction, and we want to
+ // profile the whole instruction instead of the individual thunks it consists
+ // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
+ // generate below.
+ //
// In the common case, the initializer is a constant. In this case, emit a
// device-memset call if we can. Currently StreamExecutor only supports
// zeroing and 32-bit memsets.
@@ -2710,7 +2731,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
+ return {
+ MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2728,7 +2750,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {MakeUnique<Memset32BitValueThunk>(
- pattern32, GetAllocationSlice(*hlo, index), hlo)};
+ pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2739,12 +2761,13 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {MakeUnique<Memset32BitValueThunk>(
- word, GetAllocationSlice(*hlo, index), hlo)};
+ word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
// Otherwise fall back to our slow initializer code.
- std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo);
+ std::unique_ptr<KernelThunk> kernel_thunk =
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
LaunchDimensions launch_dimensions =
CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
ir_emitter_context_->device_description());