Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 99
1 file changed, 50 insertions(+), 49 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 89dd1bb272..a809c22b33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -312,11 +312,12 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
   TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                       conv->backend_config<CudnnConvBackendConfig>());
 
-  const auto& target = conv->custom_call_target();
+  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv));
   const auto& lhs_shape = conv->operand(0)->shape();
   const auto& rhs_shape = conv->operand(1)->shape();
   const auto& conv_result_shape = conv->shape().tuple_shapes(0);
 
+  params.kind = kind;
   params.window = &conv->window();
   params.dnums = &conv->convolution_dimension_numbers();
   params.feature_group_count = conv->feature_group_count();
@@ -324,55 +325,55 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
       backend_config.algorithm(), backend_config.tensor_ops_enabled()));
   params.conv_result_scale = backend_config.conv_result_scale();
 
-  if (target == kCudnnConvForwardCallTarget) {
-    params.kind = CudnnConvKind::kForward;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &conv_result_shape;
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = result_buffer;
-  } else if (target == kCudnnConvBackwardInputCallTarget) {
-    params.kind = CudnnConvKind::kBackwardInput;
-    params.input_shape = &conv_result_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &lhs_shape;
-    params.input_buf = result_buffer;
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = operand_buffers[0];
-  } else if (target == kCudnnConvBackwardFilterCallTarget) {
-    params.kind = CudnnConvKind::kBackwardFilter;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &conv_result_shape;
-    params.output_shape = &rhs_shape;
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = result_buffer;
-    params.output_buf = operand_buffers[1];
-  } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
-    params.kind = CudnnConvKind::kForwardActivation;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &conv_result_shape;
-    params.fusion.emplace();
-    auto& fusion = *params.fusion;
-    if (backend_config.activation_mode() <
-        static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
-      fusion.mode = static_cast<se::dnn::ActivationMode>(
-          backend_config.activation_mode());
-    } else {
-      return InternalError("Bad activation mode: %s",
-                           backend_config.ShortDebugString());
-    }
-    fusion.side_input_scale = backend_config.side_input_scale();
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = result_buffer;
-    params.fusion->bias_buf = operand_buffers[2];
-    if (operand_buffers.size() >= 4) {
-      params.fusion->side_input_buf = operand_buffers[3];
+  switch (kind) {
+    case CudnnConvKind::kForward:
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &conv_result_shape;
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = result_buffer;
+      break;
+    case CudnnConvKind::kBackwardInput:
+      params.input_shape = &conv_result_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &lhs_shape;
+      params.input_buf = result_buffer;
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = operand_buffers[0];
+      break;
+    case CudnnConvKind::kBackwardFilter:
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &conv_result_shape;
+      params.output_shape = &rhs_shape;
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = result_buffer;
+      params.output_buf = operand_buffers[1];
+      break;
+    case CudnnConvKind::kForwardActivation: {
+      params.kind = CudnnConvKind::kForwardActivation;
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &conv_result_shape;
+      params.fusion.emplace();
+      auto& fusion = *params.fusion;
+      if (backend_config.activation_mode() <
+          static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+        fusion.mode = static_cast<se::dnn::ActivationMode>(
+            backend_config.activation_mode());
+      } else {
+        return InternalError("Bad activation mode: %s",
+                             backend_config.ShortDebugString());
+      }
+      fusion.side_input_scale = backend_config.side_input_scale();
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = result_buffer;
+      params.fusion->bias_buf = operand_buffers[2];
+      if (operand_buffers.size() >= 4) {
+        params.fusion->side_input_buf = operand_buffers[3];
+      }
     }
-  } else {
-    return InternalError("Unexpected custom call target: %s", target);
   }
   return params;
 }
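
For context on the new helper: the removed if/else chain above encodes a one-to-one
mapping from custom-call target strings to CudnnConvKind values, and the refactor
moves that dispatch into GetCudnnConvKind(conv). Below is a minimal sketch of that
helper, reconstructed from the mapping visible in the removed branches; the helper's
actual definition is not part of this diff, and its real signature and location may
differ.

// Sketch only: reconstructs the target-string -> CudnnConvKind mapping from the
// removed branches above. Not the committed implementation of GetCudnnConvKind.
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloInstruction* conv) {
  const auto& target = conv->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected custom call target: %s", target);
}

Centralizing the dispatch this way means GetCudnnConvParams switches over an enum
rather than comparing strings, so a compiler warning (e.g. -Wswitch) can flag any
newly added CudnnConvKind left unhandled in the default-less switch.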