about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com> 2018-09-03 04:52:24 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-03 04:57:16 -0700
commit: 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
tree: cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent: 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cudnn custom calls.
PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 78f61a4987..389a98facb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -489,8 +489,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
@@ -503,8 +503,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
@@ -517,8 +517,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();