about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 21
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 0b1ee2dc33..9bf721ecd2 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -59,6 +59,11 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward filter.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -218,6 +223,12 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward input.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
+
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
@@ -425,7 +436,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
if (match) {
return CreateCudnnConvBackwardFilter(
conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums);
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums) = MatchBackwardInput(conv);
@@ -435,15 +446,17 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
HloInstruction* rhs = reverse->mutable_operand(0);
- return CreateCudnnConvBackwardInput(
- conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
+ return CreateCudnnConvBackwardInput(conv->shape(),
+ conv->mutable_operand(0), rhs, window,
+ dnums, conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers());
+ conv->convolution_dimension_numbers(),
+ conv->feature_group_count());
}
return nullptr;