Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc')
-rw-r--r--	tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc	167
1 file changed, 95 insertions, 72 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721ecd2..228379a248 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
 
+#include <cstdlib>
 #include <numeric>
 #include <vector>
 
@@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
     HloInstruction* conv) {
   const auto no_match_result =
       std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
-  // TODO(b/31709653): Figure out if we can use grouped convolutions also on
-  // backward filter.
   if (conv->feature_group_count() > 1) {
     return no_match_result;
   }
@@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
 
 // Try to match a backward input pattern that contains "conv".
 // Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
-    HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
 
-  // TODO(b/31709653): Figure out if we can use grouped convolutions also on
-  // backward input.
+  // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also
+  // for the backward input convolution, but at least as of version 7.1.4 it
+  // is slower. This needs to be re-evaluated for future cuDNN versions. Note
+  // that we already have the necessary code down below; the only thing needed
+  // to enable it is to remove the following early return.
   if (conv->feature_group_count() > 1) {
     return no_match_result;
   }
@@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
   // Match instruction pattern.
   CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
   HloInstruction* reverse_filter = conv->mutable_operand(1);
-
-  // Match the reverse of the filter.
   ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-  const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
-  if (reverse_filter->opcode() == HloOpcode::kReverse) {
-    if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
-        !std::is_permutation(kernel_spatial_dims.begin(),
-                             kernel_spatial_dims.end(),
-                             reverse_filter->dimensions().begin())) {
-      VLOG(1)
-          << "Backward input convolution should reverse all kernel dimensions.";
-      return no_match_result;
-    }
-  } else if (reverse_filter->IsConstant()) {
-    // If the filter is a constant, we're willing to pattern-match to a
-    // backwards-input conv, on the theory that
-    //
-    // a) reversing a constant is free, and
-    // b) even if the user specified this filter as reverse(constant), we would
-    //    long ago have constant-folded away the reverse.
-    //
-    // If the constant has any other uses, reversing it isn't entirely free,
-    // since we'd now have two constants to keep in memory.  But hopefully it's
-    // free enough.
-    //
-    // TODO(jlebar): Should we do this even if the filter is not a constant?
-    // Reversing a non-constant filter is probably cheaper than padding the
-    // input!
-
-    // Nothing to do, just fall through.
-  } else {
-    // Possibly 1x1 filter.
-    for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
-      if (conv->window().dimensions(i).size() != 1) {
-        VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
-                << reverse_filter->ToString();
-        return no_match_result;
-      }
-    }
-    if (!window_util::HasBaseDilation(conv->window())) {
-      VLOG(1) << conv->ToString()
-              << " is a regular forward convolution. No need "
-                 "to fold it to a backward input convolution.";
-      return no_match_result;
-    }
+
+  // We pattern-match to a backwards input conv if:
+  //
+  //  - all spatial dims of the filter are reversed
+  //
+  // OR
+  //
+  //  - filter is 1x1 or a constant AND
+  //  - conv has base dilation (otherwise this is just a regular forward conv).
+  //
+  // The final criterion above is just for canonicalization; cudnn seems to run
+  // just as fast if we canonicalize 1x1/constant filters without base dilation
+  // to forward or backward convs.  We canonicalize to forward conv because (a)
+  // it's more natural (constant filters usually show up when doing inference,
+  // and having backwards convolutions in inference graphs would be weird), and
+  // (b) cudnn has special fusions for forward conv plus bias and activation,
+  // and we want to pattern-match to that after running this pass.
+  bool is_reversed_filter =
+      reverse_filter->opcode() == HloOpcode::kReverse &&
+      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+                             reverse_filter->dimensions());
+  bool is_1x1_filter =
+      absl::c_all_of(conv->window().dimensions(),
+                     [](const WindowDimension& d) { return d.size() == 1; });
+  if (!is_reversed_filter &&
+      !(window_util::HasBaseDilation(conv->window()) &&
+        (reverse_filter->IsConstant() || is_1x1_filter))) {
+    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+               "kReverse, or it's not a base-dilated conv with a 1x1 or "
+               "constant filter.";
+    return no_match_result;
   }
 
   // Match padding and dilation of the forward convolution.
@@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
     }
   }
 
-  // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
-  // This simplifies things for our caller, and algebraic-simplifier will later
-  // remove any unnecessary reverses.
-  if (reverse_filter->opcode() != HloOpcode::kReverse) {
+  // OK, it's a match! Switch the input feature dimension with the output
+  // feature dimension. This is the way cuDNN expects it to be.
+  dnums.set_kernel_input_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+  dnums.set_kernel_output_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+  // If we matched against a constant, we need to add a reverse op that can be
+  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+  // unnecessary reverses.
+  if (reverse_filter->opcode() != HloOpcode::kReverse &&
+      reverse_filter->IsConstant()) {
     // Create a double-reverse, which is a nop.
     HloComputation* c = conv->parent();
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      AsInt64Slice(kernel_spatial_dims)));
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      AsInt64Slice(kernel_spatial_dims)));
+    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+        reverse_filter->shape(), reverse_filter,
+        AsInt64Slice(dnums.kernel_spatial_dimensions())));
+    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+        reverse_filter->shape(), reverse_filter,
+        AsInt64Slice(dnums.kernel_spatial_dimensions())));
     TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
   }
 
-  dnums.set_kernel_input_feature_dimension(
-      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
-  dnums.set_kernel_output_feature_dimension(
-      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
-
-  return std::make_tuple(true, new_window, dnums);
+  // Calculate the 'rhs' that goes into the backward input convolution.
+  HloInstruction* rhs = reverse_filter;
+  // One reverse is subsumed by the cuDNN call.
+  if (rhs->opcode() == HloOpcode::kReverse) {
+    rhs = rhs->mutable_operand(0);
+  }
+  if (conv->feature_group_count() == 1) {
+    return std::make_tuple(true, new_window, dnums, rhs);
+  }
+
+  // Handle grouped convolutions. Because we swapped the input feature
+  // dimension with the output feature dimension, we also need to reshape the
+  // kernel so that the 'feature_group_count' parameter still makes sense. The
+  // 'feature_group_count' parameter essentially specifies how often the
+  // 'kernel_input_feature_dimension' is repeated. So when we swap these
+  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+  // 'feature_group_count' and multiply the new
+  // 'kernel_output_feature_dimension' by 'feature_group_count'.
+  Shape new_shape = rhs->shape();
+  int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+  int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+  // In the backward convolution case, the spatial dimensions become the
+  // feature dimensions, and we are guaranteed that the spatial dimensions are
+  // adjacent.
+  CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+  int64 input_features = new_shape.dimensions(input_feature_dimension);
+  int64 output_features = new_shape.dimensions(output_feature_dimension);
+  new_shape.set_dimensions(input_feature_dimension,
+                           input_features / conv->feature_group_count());
+  new_shape.set_dimensions(output_feature_dimension,
+                           output_features * conv->feature_group_count());
+  HloComputation* c = conv->parent();
+  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+  return std::make_tuple(true, new_window, dnums, rhs);
 }
 
 // Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   bool match;
   Window window;
   ConvolutionDimensionNumbers dnums;
+  HloInstruction* rhs;
 
   std::tie(match, window, dnums) = MatchBackwardFilter(conv);
   if (match) {
@@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
             window, dnums, conv->feature_group_count());
   }
 
-  std::tie(match, window, dnums) = MatchBackwardInput(conv);
+  std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
   if (match) {
-    // Backward input conv subsumes the conv plus the reverse in operand 1.
-    HloInstruction* reverse = conv->mutable_operand(1);
-    CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
-    HloInstruction* rhs = reverse->mutable_operand(0);
-
     return CreateCudnnConvBackwardInput(conv->shape(), conv->mutable_operand(0),
                                         rhs, window, dnums,
                                         conv->feature_group_count());
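
The grouped-convolution reshape in MatchBackwardInput is easiest to follow with concrete numbers. The sketch below is a minimal standalone C++ illustration of the dimension arithmetic only; it is not XLA code, and the kernel layout and sizes are assumptions made for the example. After the kernel's input and output feature dimension numbers are swapped, the size on the new input-feature side is divided by feature_group_count and the size on the new output-feature side is multiplied by it.

// Standalone sketch (not XLA code) of the grouped-conv reshape arithmetic.
// Assumes a hypothetical kernel layout [H, W, input_features, output_features].
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t feature_group_count = 4;
  // Forward conv kernel: 3x3 spatial, 64 total input features (so 16 per
  // group), 128 output features.
  std::vector<int64_t> kernel_dims = {3, 3, 16, 128};

  // The rewriter swaps the dimension *numbers*: for the backward input conv,
  // dimension 3 is now the kernel input feature dimension and dimension 2 is
  // the kernel output feature dimension.
  const int64_t input_feature_dim = 3;
  const int64_t output_feature_dim = 2;

  // Rescale so feature_group_count still divides the (new) input side.
  kernel_dims[input_feature_dim] /= feature_group_count;   // 128 -> 32
  kernel_dims[output_feature_dim] *= feature_group_count;  // 16  -> 64

  // The backward input conv consumes the 128-feature gradient as 32 input
  // features per group times 4 groups, and produces all 64 input features.
  assert(kernel_dims[input_feature_dim] * feature_group_count == 128);
  assert(kernel_dims[output_feature_dim] == 64);
  std::cout << "reshaped kernel: [" << kernel_dims[0] << ", " << kernel_dims[1]
            << ", " << kernel_dims[2] << ", " << kernel_dims[3] << "]\n";
}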
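The double-reverse inserted for constant filters relies on reversal being an involution: reversing the same dimensions twice is the identity, so adding reverse(reverse(filter)) changes nothing, and the outer reverse can then be peeled off and subsumed by the cuDNN backward-input call. A hypothetical one-dimensional sketch of that idea:

// Sketch: reversing twice along the same axis is a nop, so
// reverse(reverse(f)) == f. The rewriter keeps one reverse as the conv's
// operand, then strips it and hands the inner value to the cuDNN call,
// which performs the reversal itself.
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> filter = {1, 2, 3, 4};
  std::vector<int> once = filter;
  std::reverse(once.begin(), once.end());    // reverse #1 (subsumed by cuDNN)
  std::vector<int> twice = once;
  std::reverse(twice.begin(), twice.end());  // reverse #2 (makes the pair a nop)
  assert(twice == filter);                   // double reverse == identity
}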