aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/for_thunk.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/for_thunk.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index b36539e0cb..b3a3c5dcb4 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -27,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
+ // constructor because this SequentialThunk is logically "part of"
+ // this ForThunk, and shouldn't be profiled separately from it.
+ std::move(*body_thunk_sequence), nullptr)) {}
Status ForThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -37,11 +41,15 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
}
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (int64 i = 0; i < loop_limit_; ++i) {
+ profiler->StartHloComputation();
// Invoke loop body thunk sequence.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}