aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/while_thunk.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/while_thunk.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc32
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 30b9640c4c..1315a4183a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -29,10 +30,14 @@ WhileThunk::WhileThunk(
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
condition_result_buffer_index_(condition_result_buffer_index),
+ // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
+ // and body_thunk_sequence_ constructors because these SequentialThunks
+ // are logically "part of" this WhileThunk, and shouldn't be profiled
+ // separately from it.
condition_thunk_sequence_(MakeUnique<SequentialThunk>(
- std::move(*condition_thunk_sequence), hlo)),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ std::move(*condition_thunk_sequence), nullptr)),
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -43,14 +48,18 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
}
Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase condition_result_data =
buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
while (true) {
// Invoke thunk sequence for while 'condition' computation.
- TF_RETURN_IF_ERROR(
- condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
+ buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_condition());
// Copy the result of condition computation and break the loop if 'false'.
bool condition_result;
@@ -66,9 +75,14 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
break;
}
- // Invoke thunk sequence for while 'body' computation.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ // We measure the time of one execution of the while body computation. The
+ // while body may be executed more than once, the last measurement "wins".
+ profiler->StartHloComputation();
+ // Invoke thunk sequence for while 'body' computation, and pass on
+ // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}