From 5be479930d3dcfa3edb863703b1d73b89d45f03c Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 9 Oct 2018 17:19:24 -0700 Subject: [XLA:GPU] Use CudnnConvKind in more places. No functional change. PiperOrigin-RevId: 216451881 --- .../xla/service/gpu/cudnn_convolution_runner.cc | 99 +++++++++++----------- 1 file changed, 50 insertions(+), 49 deletions(-) (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc') diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 89dd1bb272..a809c22b33 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -312,11 +312,12 @@ StatusOr GetCudnnConvParams( TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, conv->backend_config()); - const auto& target = conv->custom_call_target(); + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); const auto& lhs_shape = conv->operand(0)->shape(); const auto& rhs_shape = conv->operand(1)->shape(); const auto& conv_result_shape = conv->shape().tuple_shapes(0); + params.kind = kind; params.window = &conv->window(); params.dnums = &conv->convolution_dimension_numbers(); params.feature_group_count = conv->feature_group_count(); @@ -324,55 +325,55 @@ StatusOr GetCudnnConvParams( backend_config.algorithm(), backend_config.tensor_ops_enabled())); params.conv_result_scale = backend_config.conv_result_scale(); - if (target == kCudnnConvForwardCallTarget) { - params.kind = CudnnConvKind::kForward; - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - } else if (target == kCudnnConvBackwardInputCallTarget) { - params.kind = CudnnConvKind::kBackwardInput; - params.input_shape = &conv_result_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &lhs_shape; - params.input_buf = result_buffer; - params.filter_buf = operand_buffers[1]; - params.output_buf = operand_buffers[0]; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - params.kind = CudnnConvKind::kBackwardFilter; - params.input_shape = &lhs_shape; - params.filter_shape = &conv_result_shape; - params.output_shape = &rhs_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = result_buffer; - params.output_buf = operand_buffers[1]; - } else if (target == kCudnnConvBiasActivationForwardCallTarget) { - params.kind = CudnnConvKind::kForwardActivation; - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.fusion.emplace(); - auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode = static_cast( - backend_config.activation_mode()); - } else { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); - } - fusion.side_input_scale = backend_config.side_input_scale(); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - params.fusion->bias_buf = operand_buffers[2]; - if (operand_buffers.size() >= 4) { - params.fusion->side_input_buf = operand_buffers[3]; + switch (kind) { + case CudnnConvKind::kForward: + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + case CudnnConvKind::kForwardActivation: { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } } - } else { - return InternalError("Unexpected custom call target: %s", target); } return params; } -- cgit v1.2.3