about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/conditional_thunk.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/conditional_thunk.cc 21
1 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 77a48965e0..5780e0af40 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,8 +33,11 @@ ConditionalThunk::ConditionalThunk(
predicate_buffer_index_(predicate_buffer_index),
true_operand_buffer_index_(true_operand_buffer_index),
false_operand_buffer_index_(false_operand_buffer_index),
- true_thunk_(std::move(true_thunk_sequence), hlo),
- false_thunk_(std::move(false_thunk_sequence), hlo) {}
+ // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_
+ // constructors because these SequentialThunks are logically "part of"
+ // this ConditionalThunk, and shouldn't be profiled separately from it.
+ true_thunk_(std::move(true_thunk_sequence), nullptr),
+ false_thunk_(std::move(false_thunk_sequence), nullptr) {}
Status ConditionalThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -43,7 +47,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
}
Status ConditionalThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// Copy the predicate value from device.
bool predicate;
se::DeviceMemoryBase predicate_address =
@@ -59,10 +65,15 @@ Status ConditionalThunk::ExecuteOnStream(
// Execute the true or the false computation depending on the value of the
// predicate.
if (predicate) {
- TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(
+ true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->true_computation());
} else {
+ profiler->StartHloComputation();
TF_RETURN_IF_ERROR(
- false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->false_computation());
}
return Status::OK();