author    Justin Lebar <jlebar@google.com>    2018-10-09 17:19:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-09 17:23:55 -0700
commit 5be479930d3dcfa3edb863703b1d73b89d45f03c (patch)
tree   4d89676c2a1b6ddf0cc1da3873536d6471b50321 /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
parent 9bd459e4ceba14f9bb1af98d52a109325de952e8 (diff)
[XLA:GPU] Use CudnnConvKind in more places.
No functional change.

PiperOrigin-RevId: 216451881
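What this refactor does: instead of comparing conv->custom_call_target() against each known call-target string, the runner now asks GetCudnnConvKind(conv) for a typed CudnnConvKind and switches on it, so every dispatch site shares one string-to-enum mapping. Below is a minimal standalone sketch of the before/after shape of the change; the stand-in types and the call-target string values are illustrative, not the actual XLA declarations.

#include <string>

// Stand-in for the enum declared alongside the runner.
enum class CudnnConvKind {
  kForward,
  kBackwardInput,
  kBackwardFilter,
  kForwardActivation,
};

// Call-target constants as named in the diff; string values illustrative.
constexpr char kCudnnConvForwardCallTarget[] = "__cudnn$convForward";
constexpr char kCudnnConvBackwardInputCallTarget[] = "__cudnn$convBackwardInput";

// Before: each dispatch site walks a string-comparison chain and needs a
// trailing else to reject unknown targets at runtime.
const char* DescribeByTarget(const std::string& target) {
  if (target == kCudnnConvForwardCallTarget) {
    return "forward";
  } else if (target == kCudnnConvBackwardInputCallTarget) {
    return "backward-input";
  } else {
    return "unknown";  // only discovered when the code actually runs
  }
}

// After: resolve the target to an enum once, then switch.  With -Wswitch
// (implied by -Wall in GCC and Clang) the compiler warns at every switch
// that fails to handle a newly added CudnnConvKind value.
const char* DescribeByKind(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardInput:
      return "backward-input";
    case CudnnConvKind::kBackwardFilter:
      return "backward-filter";
    case CudnnConvKind::kForwardActivation:
      return "forward with fused bias/activation";
  }
  return "unreachable";  // every enumerator is handled above
}

That is why a "no functional change" cleanup still pays off: the switch turns a runtime string-mismatch error into a compile-time exhaustiveness check whenever a new convolution kind is added.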
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 99
1 file changed, 50 insertions(+), 49 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 89dd1bb272..a809c22b33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -312,11 +312,12 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
conv->backend_config<CudnnConvBackendConfig>());
- const auto& target = conv->custom_call_target();
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv));
const auto& lhs_shape = conv->operand(0)->shape();
const auto& rhs_shape = conv->operand(1)->shape();
const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+ params.kind = kind;
params.window = &conv->window();
params.dnums = &conv->convolution_dimension_numbers();
params.feature_group_count = conv->feature_group_count();
@@ -324,55 +325,55 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
backend_config.algorithm(), backend_config.tensor_ops_enabled()));
params.conv_result_scale = backend_config.conv_result_scale();
- if (target == kCudnnConvForwardCallTarget) {
- params.kind = CudnnConvKind::kForward;
- params.input_shape = &lhs_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &conv_result_shape;
- params.input_buf = operand_buffers[0];
- params.filter_buf = operand_buffers[1];
- params.output_buf = result_buffer;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params.kind = CudnnConvKind::kBackwardInput;
- params.input_shape = &conv_result_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &lhs_shape;
- params.input_buf = result_buffer;
- params.filter_buf = operand_buffers[1];
- params.output_buf = operand_buffers[0];
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params.kind = CudnnConvKind::kBackwardFilter;
- params.input_shape = &lhs_shape;
- params.filter_shape = &conv_result_shape;
- params.output_shape = &rhs_shape;
- params.input_buf = operand_buffers[0];
- params.filter_buf = result_buffer;
- params.output_buf = operand_buffers[1];
- } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
- params.kind = CudnnConvKind::kForwardActivation;
- params.input_shape = &lhs_shape;
- params.filter_shape = &rhs_shape;
- params.output_shape = &conv_result_shape;
- params.fusion.emplace();
- auto& fusion = *params.fusion;
- if (backend_config.activation_mode() <
- static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
- fusion.mode = static_cast<se::dnn::ActivationMode>(
- backend_config.activation_mode());
- } else {
- return InternalError("Bad activation mode: %s",
- backend_config.ShortDebugString());
- }
- fusion.side_input_scale = backend_config.side_input_scale();
- params.input_buf = operand_buffers[0];
- params.filter_buf = operand_buffers[1];
- params.output_buf = result_buffer;
- params.fusion->bias_buf = operand_buffers[2];
- if (operand_buffers.size() >= 4) {
- params.fusion->side_input_buf = operand_buffers[3];
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ break;
+ case CudnnConvKind::kForwardActivation: {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
}
- } else {
- return InternalError("Unexpected custom call target: %s", target);
}
return params;
}
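A side note on the kForwardActivation arm above: backend_config.activation_mode() is a raw integer deserialized from the backend config, so the code range-checks it against se::dnn::ActivationMode::kNumActivationModes before static_casting to the enum. The sketch below shows that validate-then-cast pattern with a hypothetical stand-in enum; unlike the original, it also rejects negative values rather than relying on the field's encoding.

#include <cstdint>
#include <optional>

// Hypothetical stand-in for se::dnn::ActivationMode; the final sentinel
// enumerator bounds the valid range, mirroring kNumActivationModes.
enum class ActivationMode : int64_t {
  kNone,
  kSigmoid,
  kRelu,
  kNumActivationModes,  // sentinel, not a real mode
};

// Never static_cast an untrusted integer straight to an enum: a value outside
// the enumerator range would produce a nonsense mode.  Check first, cast after.
std::optional<ActivationMode> ToActivationMode(int64_t raw) {
  if (raw < 0 ||
      raw >= static_cast<int64_t>(ActivationMode::kNumActivationModes)) {
    return std::nullopt;  // the runner returns InternalError here instead
  }
  return static_cast<ActivationMode>(raw);
}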