aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-09-10 17:34:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 17:38:20 -0700
commitbfc1897518063bfa1d62d9a3cfe5e6362c0d09d9 (patch)
treec4952f81902029b58e319025b78f991b6276f053 /tensorflow/compiler/xla/service/gpu
parentfea74706aaa314cc77ec66c2c986365590e8df27 (diff)
[XLA:GPU] Don't canonicalize forward convs with constant filters to backwards conv.
There's no right answer between these two choices, and our benchmarks show no performance difference. But canonicalizing to forward conv makes later pattern-matching passes work properly. PiperOrigin-RevId: 212366534
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc87
1 files changed, 37 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 4a6a84d87d..3d1266355b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -234,51 +234,38 @@ MatchBackwardInput(HloInstruction* conv) {
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
-
- // Match the reverse of the filter.
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
- const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
- if (reverse_filter->opcode() == HloOpcode::kReverse) {
- if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
- !std::is_permutation(kernel_spatial_dims.begin(),
- kernel_spatial_dims.end(),
- reverse_filter->dimensions().begin())) {
- VLOG(1)
- << "Backward input convolution should reverse all kernel dimensions.";
- return no_match_result;
- }
- } else if (reverse_filter->IsConstant()) {
- // If the filter is a constant, we're willing to pattern-match to a
- // backwards-input conv, on the theory that
- //
- // a) reversing a constant is free, and
- // b) even if the user specified this filter as reverse(constant), we would
- // long ago have constant-folded away the reverse.
- //
- // If the constant has any other uses, reversing it isn't entirely free,
- // since we'd now have two constants to keep in memory. But hopefully it's
- // free enough.
- //
- // TODO(jlebar): Should we do this even if the filter is not a constant?
- // Reversing a non-constant filter is probably cheaper than padding the
- // input!
-
- // Nothing to do, just fall through.
- } else {
- // Possibly 1x1 filter.
- for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
- if (conv->window().dimensions(i).size() != 1) {
- VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
- << reverse_filter->ToString();
- return no_match_result;
- }
- }
- if (!window_util::HasBaseDilation(conv->window())) {
- VLOG(1) << conv->ToString()
- << " is a regular forward convolution. No need "
- "to fold it to a backward input convolution.";
- return no_match_result;
- }
+
+ // We pattern-match to a backwards input conv if:
+ //
+ // - all spatial dims of the filter are reversed
+ //
+ // OR
+ //
+ // - filter is 1x1 or a constant AND
+ // - conv has base dilation (otherwise this is just a regular forward conv).
+ //
+ // The final criterion above is just for canonicalization; cudnn seems to run
+ // just as fast if we canonicalize 1x1/constant filters without base dilation
+ // to forward or backward convs. We canonicalize to forward conv because (a)
+ // it's more natural (constant filters usually show up when doing inference,
+ // and having backwards convolutions in inference graphs would be weird), and
+ // (b) cudnn has special fusions for forward conv plus bias and activation,
+ // and we want to pattern-match to that after running this pass.
+ bool is_reversed_filter =
+ reverse_filter->opcode() == HloOpcode::kReverse &&
+ absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+ reverse_filter->dimensions());
+ bool is_1x1_filter =
+ absl::c_all_of(conv->window().dimensions(),
+ [](const WindowDimension& d) { return d.size() == 1; });
+ if (!is_reversed_filter &&
+ !(window_util::HasBaseDilation(conv->window()) &&
+ (reverse_filter->IsConstant() || is_1x1_filter))) {
+ VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+ "kReverse, or it's not a base-dialted conv with a 1x1 or "
+ "constant filter.";
+ return no_match_result;
}
// Match padding and dilation of the forward convolution.
@@ -417,12 +404,12 @@ MatchBackwardInput(HloInstruction* conv) {
reverse_filter->IsConstant()) {
// Create a double-reverse, which is a nop.
HloComputation* c = conv->parent();
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
- reverse_filter = c->AddInstruction(
- HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
- AsInt64Slice(kernel_spatial_dims)));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
+ reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+ reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(dnums.kernel_spatial_dimensions())));
TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
}