aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
author: Adrian Kuegel <akuegel@google.com> 2018-09-03 04:52:24 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-03 04:57:16 -0700
commit: 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
tree: cc76401cd5057d454fc0dc82178c3747a122cb3f /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent: 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cudnn custom calls.

PiperOrigin-RevId: 211339551
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc 10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 98cc21ccac..9d85d746d8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -166,9 +166,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
Shape old_conv_shape = conv->shape().tuple_shapes(0);
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
- new_conv_window,
- conv->convolution_dimension_numbers());
+ auto new_conv = CreateCudnnConvForward(
+ old_conv_shape, new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -247,7 +247,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -312,7 +312,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.