Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc  167
1 file changed, 95 insertions(+), 72 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721ecd2..228379a248 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include <cstdlib>
#include <numeric>
#include <vector>
@@ -59,8 +60,6 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward filter.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -218,13 +217,16 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
- HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
- std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
- // TODO(b/31709653): Figure out if we can use grouped convolutions also on
- // backward input.
+ // TODO(b/31709653): Theoretically cuDNN also supports grouped convolutions
+ // for the backward input convolution, but at least as of version 7.1.4 it
+ // is slower. This needs to be re-evaluated for future cuDNN versions.
+ // Note that we already have the necessary code below; all that is needed
+ // to enable it is to remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
@@ -232,51 +234,38 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else if (reverse_filter->IsConstant()) {
- // If the filter is a constant, we're willing to pattern-match to a
- // backwards-input conv, on the theory that
- //
- // a) reversing a constant is free, and
- // b) even if the user specified this filter as reverse(constant), we would
- // long ago have constant-folded away the reverse.
- //
- // If the constant has any other uses, reversing it isn't entirely free,
- // since we'd now have two constants to keep in memory. But hopefully it's
- // free enough.
- //
- // TODO(jlebar): Should we do this even if the filter is not a constant?
- // Reversing a non-constant filter is probably cheaper than padding the
- // input!
-
- // Nothing to do, just fall through.
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
// Match padding and dilation of the forward convolution.
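[Editor's note on the new predicate above: the reversed-filter and 1x1 checks reduce to a permutation test and an all-of test. A minimal standalone sketch in plain C++ of what the two conditions compute, with hypothetical helper names, std:: algorithms in place of absl, and no HLO types:]

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the kReverse check: the reverse op must flip
    // exactly the kernel's spatial dimensions, in any order.
    bool IsReversedFilter(const std::vector<int64_t>& kernel_spatial_dims,
                          const std::vector<int64_t>& reversed_dims) {
      return kernel_spatial_dims.size() == reversed_dims.size() &&
             std::is_permutation(kernel_spatial_dims.begin(),
                                 kernel_spatial_dims.end(),
                                 reversed_dims.begin());
    }

    // Hypothetical stand-in for the 1x1 check: every window dimension has
    // size 1.
    bool Is1x1Filter(const std::vector<int64_t>& window_sizes) {
      return std::all_of(window_sizes.begin(), window_sizes.end(),
                         [](int64_t size) { return size == 1; });
    }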
@@ -401,26 +390,64 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
}
}
- // OK, it's a match! Canonicalize the conv's filter so that it's a reverse.
- // This simplifies things for our caller, and algebraic-simplifier will later
- // remove any unnecessary reverses.
- if (reverse_filter->opcode() != HloOpcode::kReverse) {
+ // OK, it's a match! Switch the input feature dimension with the output
+ // feature dimension. This is the way cuDNN expects it to be.
+ dnums.set_kernel_input_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+ // If we matched against a constant, we need to add a reverse op that can be
+ // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+ // unnecessary reverses.
+ if (reverse_filter->opcode() != HloOpcode::kReverse &&
+ reverse_filter->IsConstant()) {
// Create a double-reverse, which is a nop.
HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}
- dnums.set_kernel_input_feature_dimension(
- conv->convolution_dimension_numbers().kernel_output_feature_dimension());
- dnums.set_kernel_output_feature_dimension(
- conv->convolution_dimension_numbers().kernel_input_feature_dimension());
- return std::make_tuple(true, new_window, dnums);
+ // Calculate the 'rhs' that goes into the backward input convolution.
+ HloInstruction* rhs = reverse_filter;
+ // One reverse is subsumed by the cuDNN call.
+ if (rhs->opcode() == HloOpcode::kReverse) {
+ rhs = rhs->mutable_operand(0);
+ }
+ if (conv->feature_group_count() == 1) {
+ return std::make_tuple(true, new_window, dnums, rhs);
+ }
+
+ // Handle grouped convolutions. Because we swapped the input feature dimension
+ // with the output feature dimension, we need to also reshape the kernel so
+ // that the 'feature_group_count' parameter still makes sense. The
+ // 'feature_group_count' parameter essentially specifies how often the
+ // 'kernel_input_feature_dimension' is repeated. So when we swap these
+ // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+ // 'feature_group_count' and multiply the new
+ // 'kernel_output_feature_dimension' by 'feature_group_count'.
+ Shape new_shape = rhs->shape();
+ int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+ int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+ // In the backward convolution case, the spatial dimensions become the
+ // feature dimensions, and we are guaranteed that the spatial dimensions are
+ // adjacent.
+ CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+ int64 input_features = new_shape.dimensions(input_feature_dimension);
+ int64 output_features = new_shape.dimensions(output_feature_dimension);
+ new_shape.set_dimensions(input_feature_dimension,
+ input_features / conv->feature_group_count());
+ new_shape.set_dimensions(output_feature_dimension,
+ output_features * conv->feature_group_count());
+ HloComputation* c = conv->parent();
+ rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+ return std::make_tuple(true, new_window, dnums, rhs);
}
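[Editor's note: to make the grouped-convolution reshape arithmetic above concrete, here is a standalone sketch. The shape values and feature_group_count are made up for illustration; only the divide/multiply bookkeeping mirrors the code:]

    #include <cassert>
    #include <cstdint>

    int main() {
      // Hypothetical kernel after the feature-dimension swap: 64 features on
      // the new input feature dimension, 32 on the new output one, 2 groups.
      int64_t feature_group_count = 2;
      int64_t input_features = 64;
      int64_t output_features = 32;

      // Divide the new input feature dimension and multiply the new output
      // feature dimension so 'feature_group_count' still makes sense.
      int64_t new_input = input_features / feature_group_count;    // 32
      int64_t new_output = output_features * feature_group_count;  // 64

      // The reshape preserves the total element count.
      assert(input_features * output_features == new_input * new_output);
      return 0;
    }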
// Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
bool match;
Window window;
ConvolutionDimensionNumbers dnums;
+ HloInstruction* rhs;
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
@@ -439,13 +467,8 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
window, dnums, conv->feature_group_count());
}
- std::tie(match, window, dnums) = MatchBackwardInput(conv);
+ std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- // Backward input conv subsumes the conv plus the reverse in operand 1.
- HloInstruction* reverse = conv->mutable_operand(1);
- CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
- HloInstruction* rhs = reverse->mutable_operand(0);
-
return CreateCudnnConvBackwardInput(conv->shape(),
conv->mutable_operand(0), rhs, window,
dnums, conv->feature_group_count());
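[Editor's note on the call-site change in RunOnInstruction: because MatchBackwardInput now returns the rhs directly, the caller no longer digs the operand out of a kReverse it assumed was present. A minimal sketch of the std::tie unpacking idiom used here, with simplified types and no XLA dependencies:]

    #include <string>
    #include <tuple>

    // Hypothetical four-element result mirroring MatchBackwardInput's
    // (match, window, dnums, rhs) tuple, with simplified types.
    std::tuple<bool, int, std::string, const char*> Match() {
      return std::make_tuple(true, 7, "dnums", "rhs");
    }

    int main() {
      bool match;
      int window;
      std::string dnums;
      const char* rhs;
      // std::tie unpacks into existing locals, which is why the diff only
      // needed to declare one extra variable (rhs) at the call site.
      std::tie(match, window, dnums, rhs) = Match();
      return match ? 0 : 1;
    }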