about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc 14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 68099fd638..7b172812c3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
}
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
std::tie(operand_desc, scale_offset_desc) =
MakeDescriptors(hlo_instruction()->shape(), feature_index_);
se::DeviceMemory<float> output(buffer_allocations.GetDeviceAddress(output_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
/*is_training=*/false, //
/*var_to_inv_var=*/nullptr, //
/*inv_var_to_var=*/nullptr);
+
if (!stream->ok()) {
return InternalError("BatchNormalizationForward call failed.");
}
@@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
}
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
@@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_inv_stddev_));
se::DeviceMemory<float> null_device_ptr(nullptr);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
}
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
@@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
se::DeviceMemory<float> output_grad_offset(
buffer_allocations.GetDeviceAddress(output_grad_offset_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationBackward(
se::DeviceMemory<float>(
buffer_allocations.GetDeviceAddress(grad_output_)),