diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/for_thunk.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/for_thunk.cc | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b36539e0cb..b3a3c5dcb4 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -27,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_( - MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {} + body_thunk_sequence_(MakeUnique<SequentialThunk>( + // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ + // constructor because this SequentialThunk is logically "part of" + // this ForThunk, and shouldn't be profiled separately from it. + std::move(*body_thunk_sequence), nullptr)) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -37,11 +41,15 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { + profiler->StartHloComputation(); // Invoke loop body thunk sequence. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } |