Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/convolution_folding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/convolution_folding.cc | 443
1 file changed, 443 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
new file mode 100644
index 0000000000..dd1b09c6cc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
@@ -0,0 +1,443 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// Try to match a backward filter pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+std::tuple<bool, std::vector<HloInstruction*>, Window,
+           ConvolutionDimensionNumbers>
+MatchBackwardFilter(HloInstruction* conv) {
+  const auto no_match_result =
+      std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
+                      ConvolutionDimensionNumbers());
+  // Step 1: match the instruction pattern without considering the paddings
+  // and dimension numbers just yet. We may need some generic pattern matcher
+  // similar to external/llvm/include/llvm/IR/PatternMatch.h
+  //
+  // Backward filter convolution is implemented in XLA as the forward
+  // convolution of padded activations and dilated gradients. Padding on
+  // activations and dilation on gradients are specified in the "window" field
+  // of the forward convolution.
+  //
+  //        activations        gradients
+  //              \                /
+  //               v              v
+  //              Convolution
+  //                 conv
+  //                   |
+  //                   v
+  //              Transpose (optional if identity transposition)
+  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+  // If the forward convolution is followed by a transpose, we can fuse the
+  // transpose into the backward convolution as well.
+  HloInstruction* transpose = nullptr;
+  if (conv->user_count() == 1) {
+    HloInstruction* single_user = *conv->users().begin();
+    if (single_user->opcode() == HloOpcode::kTranspose) {
+      transpose = single_user;
+    }
+  }
+
+  // Step 2: match paddings and dimension numbers of the forward convolution.
+  const ConvolutionDimensionNumbers& conv_dnums =
+      conv->convolution_dimension_numbers();
+  auto batch_dim = conv_dnums.batch_dimension();
+  auto feature_dim = conv_dnums.feature_dimension();
+  auto spatial_dims = conv_dnums.spatial_dimensions();
+
+  for (const WindowDimension& window_dim : conv->window().dimensions()) {
+    if (window_dim.stride() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have stride of 1.";
+      return no_match_result;
+    }
+    if (window_dim.base_dilation() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have no base (LHS) dilation.";
+      return no_match_result;
+    }
+    if (window_dim.padding_low() < 0) {
+      VLOG(1) << "Padding low should be non-negative.";
+      return no_match_result;
+    }
+    // Padding high will be checked in Step 3.
+  }
+  if (transpose == nullptr &&
+      !window_util::HasWindowDilation(conv->window())) {
+    VLOG(1) << conv->ToString()
+            << " is a regular forward convolution. No need "
+               "to fold it to a backward filter convolution.";
+    return no_match_result;
+  }
+
+  // Step 3: fuse the matched HLOs into a backward convolution instruction.
+  //
+  // Compute the window of the backward convolution.
+  Window backward_conv_window;
+  for (int i = 0; i < 2; ++i) {
+    WindowDimension* dim = backward_conv_window.add_dimensions();
+    // The window size of the backward convolution equals the output size of
+    // the forward convolution.
+    int64 filter_size = conv->shape().dimensions(spatial_dims[i]);
+    dim->set_size(filter_size);
+    // The window stride equals the window dilation of the forward
+    // convolution.
+    dim->set_stride(conv->window().dimensions(i).window_dilation());
+    // The window's low padding is the same as the low padding of the
+    // activations.
+    dim->set_padding_low(conv->window().dimensions(i).padding_low());
+
+    int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+    int64 output_size = conv->window().dimensions(i).size();
+    // Compute the range of the amount of valid high padding. We first compute
+    // min_padding_high, the amount of padding on the right/bottom to ensure
+    // the last patch ends at the border, i.e.,
+    //
+    //   input_size + dim->padding_low() + min_padding_high
+    //     = (output_size - 1) * stride + filter_size
+    //
+    // Because convolution ignores trailing incomplete windows, any amount of
+    // padding high from min_padding_high to min_padding_high+stride-1
+    // (max_padding_high) has the same effect.
+    int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
+    int64 min_padding_high =
+        padded_input_size - input_size - dim->padding_low();
+    int64 max_padding_high = min_padding_high + dim->stride() - 1;
+    CHECK_GE(dim->padding_low(), 0);
+    // In practice, since cuDNN convolution only supports even padding, we
+    // make the amount of high padding the same as the amount of low padding
+    // as long as it is between min_padding_high and max_padding_high. If it
+    // is not in that range, we pick the one that's closest to
+    // dim->padding_low() and let PadInsertion canonicalize the resultant
+    // backward convolution later. Picking the closest one minimizes the cost
+    // of the kPad instruction to be inserted by PadInsertion.
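+    // Worked example (illustrative numbers): with input_size = 5,
+    // padding_low = 1, stride = 2, output_size = 3 and filter_size = 3,
+    // padded_input_size = 3 + (3 - 1) * 2 = 7, so
+    // min_padding_high = 7 - 5 - 1 = 1 and max_padding_high = 1 + 2 - 1 = 2.
+    // Since padding_low (1) lies in [1, 2], padding_high is set to 1 below,
+    // which keeps the padding even for cuDNN.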
+    if (dim->padding_low() >= min_padding_high &&
+        dim->padding_low() <= max_padding_high) {
+      dim->set_padding_high(dim->padding_low());
+    } else {
+      if (dim->padding_low() < min_padding_high) {
+        dim->set_padding_high(min_padding_high);
+      } else {
+        dim->set_padding_high(max_padding_high);
+      }
+    }
+    if (dim->padding_high() < 0) {
+      LOG(ERROR)
+          << "Fusing this pattern to backward filter convolution would cause "
+             "negative padding ("
+          << dim->padding_high()
+          << ") on right/bottom of the weight gradients, which is not "
+             "supported by PadInsertion (b/32744257). Falling back to "
+             "unfused convolution for instruction: "
+          << conv->ToString();
+      return no_match_result;
+    }
+  }
+
+  // To make future HLO passes easier, we canonicalize the fused expression by
+  // adding an identity transposition if it's omitted in the pattern.
+  if (transpose == nullptr) {
+    // Create an identity transposition with the same rank as the forward
+    // convolution.
+    HloComputation* parent_computation = conv->parent();
+    std::vector<int64> transpose_dimensions(ShapeUtil::Rank(conv->shape()));
+    std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0);
+    transpose =
+        parent_computation->AddInstruction(HloInstruction::CreateTranspose(
+            conv->shape(), conv, transpose_dimensions));
+    parent_computation->ReplaceUsesOfInstruction(conv, transpose);
+  }
+
+  // Restore the dimension numbers of the backward convolution from the
+  // forward convolution. The two activation dimensions are reversed (batch
+  // and feature).
+  ConvolutionDimensionNumbers backward_conv_dnums;
+  backward_conv_dnums.set_batch_dimension(feature_dim);
+  backward_conv_dnums.set_feature_dimension(batch_dim);
+  for (int i = 0; i < 2; ++i) {
+    backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]);
+  }
+  // The dimension numbering of the output of the forward convolution (before
+  // transposition) is the same as that of the activations (according to the
+  // semantics of kConvolution). The batch dimension of the activations should
+  // be treated as the input feature dimension, and the feature dimension
+  // should be treated as the output feature.
+  //
+  // The output of the forward convolution needs to be transposed to fit into
+  // the dimension numbering of the weight gradients. This transposition maps
+  // dimension i to PositionInContainer(transpose->dimensions(), i).
+  backward_conv_dnums.set_kernel_input_feature_dimension(
+      PositionInContainer(transpose->dimensions(), batch_dim));
+  backward_conv_dnums.set_kernel_output_feature_dimension(
+      PositionInContainer(transpose->dimensions(), feature_dim));
+  for (int i = 0; i < 2; ++i) {
+    backward_conv_dnums.add_kernel_spatial_dimensions(
+        PositionInContainer(transpose->dimensions(), spatial_dims[i]));
+  }
+
+  return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}),
+                         backward_conv_window, backward_conv_dnums);
+}
+
+// Try to match a backward input pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
+std::tuple<bool, std::vector<HloInstruction*>, Window,
+           ConvolutionDimensionNumbers>
+MatchBackwardInput(HloInstruction* conv) {
+  const auto no_match_result =
+      std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
+                      ConvolutionDimensionNumbers());
+
+  // Match instruction pattern.
+  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+  HloInstruction* reverse_filter = conv->mutable_operand(1);
+
+  // Match the reverse of the filter.
+  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+  const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
+  if (reverse_filter->opcode() == HloOpcode::kReverse) {
+    if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
+        !std::is_permutation(kernel_spatial_dims.begin(),
+                             kernel_spatial_dims.end(),
+                             reverse_filter->dimensions().begin())) {
+      VLOG(1)
+          << "Backward input convolution should reverse all kernel dimensions.";
+      return no_match_result;
+    }
+  } else {
+    // Possibly 1x1 filter.
+    for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
+      if (conv->window().dimensions(i).size() != 1) {
+        VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
+                << reverse_filter->ToString();
+        return no_match_result;
+      }
+    }
+    if (!window_util::HasBaseDilation(conv->window())) {
+      VLOG(1) << conv->ToString()
+              << " is a regular forward convolution. No need "
+                 "to fold it to a backward input convolution.";
+      return no_match_result;
+    }
+  }
+
+  // Match padding and dilation of the forward convolution.
+  for (const WindowDimension& window_dim : conv->window().dimensions()) {
+    if (window_dim.stride() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have stride of 1.";
+      return no_match_result;
+    }
+    if (window_dim.window_dilation() != 1) {
+      VLOG(1) << "Forward convolution's window "
+              << conv->window().ShortDebugString()
+              << " should have no window dilation.";
+      return no_match_result;
+    }
+  }
+
+  const auto& spatial_dims = dnums.spatial_dimensions();
+  CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size());
+
+  const Window& old_window = conv->window();
+  Window new_window = old_window;
+  for (size_t i = 0; i < spatial_dims.size(); ++i) {
+    // Restore backward convolution's padding config from the matched pattern.
+    // See the comment in tensorflow/core/kernels/conv_grad_ops.cc
+    // for how we convert backward input convolution to a variant of forward
+    // convolution.
+    //
+    // The stride of the backward convolution
+    //   = the base dilation factor of the forward convolution
+    auto dim = new_window.mutable_dimensions(i);
+    dim->set_stride(old_window.dimensions(i).base_dilation());
+
+    // The low padding = kernel_size - 1 - low padding on the gradients
+    // Make sure the low padding is not negative.
+    auto kernel_size = old_window.dimensions(i).size();
+    auto backward_padding_low =
+        kernel_size - 1 - old_window.dimensions(i).padding_low();
+    if (backward_padding_low < 0) {
+      LOG(ERROR)
+          << "The low padding of the backward convolution would be negative ("
+          << backward_padding_low
+          << "), which isn't supported by PadInsertion for now (b/32744257).";
+      return no_match_result;
+    }
+    dim->set_padding_low(backward_padding_low);
+
+    // Compute the range of the amount of padding on the right/bottom of the
+    // activations. XLA's convolution requires all patches to be within the
+    // padded base. This gives us flexibility to choose the amount of high
+    // padding from a set of values without changing the result of the
+    // backward convolution. The minimum amount (min_padding_high) makes the
+    // last patch end at the border. The maximum amount (max_padding_high)
+    // equals min_padding_high+stride-1 -- max_padding_high+1 would cause the
+    // output size to change.
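+    // To make the quantities below concrete, the ab/xy pattern illustrated
+    // further down gives kernel_size = 2, stride = 2, output_size = 2 (the
+    // gradients ab), unpadded_input_size = 3 (the activations ABC) and
+    // backward_padding_low = 0, so padded_input_size = 2 + 2 * (2 - 1) = 4,
+    // total_pad_size = 4 - 3 = 1, min_padding_high = 1 - 0 = 1 and
+    // max_padding_high = 1 + 2 - 1 = 2.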
+    auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]);
+    auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
+    auto total_pad_size = padded_input_size - unpadded_input_size;
+    auto min_padding_high = total_pad_size - backward_padding_low;
+    auto max_padding_high = min_padding_high + dim->stride() - 1;
+
+    if (backward_padding_low >= min_padding_high &&
+        backward_padding_low <= max_padding_high) {
+      // In the best case (most likely), if backward_padding_low is in the
+      // range of the amounts of valid high padding, we choose
+      // backward_padding_low because cuDNN supports even padding only.
+      dim->set_padding_high(backward_padding_low);
+    } else {
+      // Otherwise, we choose the amount that's closest to
+      // backward_padding_low, and PadInsertion will later insert kSlice
+      // instructions to enforce even padding.
+      //
+      // For example, consider the backward convolution pattern
+      //
+      //   ab     xy
+      //   | pad  | reverse
+      //  .a.b    yx
+      //     \   /
+      //      ABC
+      //
+      // The amount of low padding on activations (in backward convolution) is
+      //   backward_padding_low = kernel_size - 1 - forward_padding_low
+      //                        = 2 - 1 - 1 = 0
+      //
+      // The amount of padding high must be between 1 and 2, in order to make
+      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
+      // the range of [1,2], so we pick the closest valid amount of padding
+      // high, which is 1 in this case. Therefore, we fuse the above pattern
+      // to
+      //
+      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
+      if (backward_padding_low < min_padding_high) {
+        dim->set_padding_high(min_padding_high);
+      } else {
+        dim->set_padding_high(max_padding_high);
+      }
+    }
+    // PadInsertion doesn't handle backward input convolution with negative
+    // padding for now. So fall back to unfused convolution in case of
+    // negative padding. For example,
+    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
+    // could be fused to
+    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
+    // with positive padding low but negative padding high.
+    if (dim->padding_high() < 0) {
+      LOG(ERROR) << "Fusing this pattern to backward convolution would cause "
+                    "negative padding ("
+                 << dim->padding_high()
+                 << ") on right/bottom of the activations, which is not "
+                    "supported by PadInsertion (b/32744257). Falling back to "
+                    "unfused convolution for instruction: "
+                 << conv->ToString();
+      return no_match_result;
+    }
+  }
+
+  // Fuse the matched HLOs into a backward convolution instruction.
+  //
+  // If the reverse is omitted (for 1x1 filters) in the original pattern, we
+  // add it back in the fusion instruction so that later passes (such as
+  // PadInsertion) can handle such fusion instructions easily.
+  if (reverse_filter->opcode() != HloOpcode::kReverse) {
+    reverse_filter = reverse_filter->parent()->AddInstruction(
+        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+                                      AsInt64Slice(kernel_spatial_dims)));
+    conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter);
+  }
+  dnums.set_kernel_input_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+  dnums.set_kernel_output_feature_dimension(
+      conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+  return std::make_tuple(true,
+                         std::vector<HloInstruction*>({conv, reverse_filter}),
+                         new_window, dnums);
+}
+}  // namespace
+
+StatusOr<bool> ConvolutionFolding::Run(HloModule* module) {
+  HloComputation* entry_computation = module->entry_computation();
+  std::vector<HloInstruction*> convs;
+  for (const auto& hlo : entry_computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kConvolution) {
+      convs.push_back(hlo.get());
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conv : convs) {
+    bool match;
+    std::vector<HloInstruction*> hlos_to_fuse;
+    Window window;
+    ConvolutionDimensionNumbers dnums;
+    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv);
+    if (match) {
+      VLOG(2) << "Fuse instructions";
+      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+        VLOG(2) << "  " << hlo_to_fuse->ToString();
+      }
+      HloInstruction* backward_convolution =
+          entry_computation->CreateFusionInstructionForBackwardConvolution(
+              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter,
+              window, dnums);
+      VLOG(2) << "to backward filter convolution";
+      VLOG(2) << "  " << backward_convolution->ToString();
+      changed = true;
+      continue;
+    }
+
+    std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv);
+    if (match) {
+      VLOG(2) << "Fuse instructions";
+      for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+        VLOG(2) << "  " << hlo_to_fuse->ToString();
+      }
+      HloInstruction* backward_convolution =
+          entry_computation->CreateFusionInstructionForBackwardConvolution(
+              hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput,
+              window, dnums);
+      VLOG(2) << "to backward input convolution";
+      VLOG(2) << "  " << backward_convolution->ToString();
+      changed = true;
+      continue;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
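For reference, a minimal sketch of how this pass would be driven from a compilation pipeline; the wrapper function below is a hypothetical illustration (only ConvolutionFolding::Run is defined by this change), and it assumes the status macros and logging facilities used elsewhere in XLA.

  // Illustrative sketch only: run ConvolutionFolding over a module and report
  // whether any forward convolution was folded into a backward-convolution
  // fusion (kConvBackwardFilter or kConvBackwardInput).
  tensorflow::Status FoldConvolutions(xla::HloModule* module) {
    xla::gpu::ConvolutionFolding folding;
    // Run() returns StatusOr<bool>; true means at least one convolution was
    // rewritten.
    TF_ASSIGN_OR_RETURN(bool changed, folding.Run(module));
    VLOG(1) << "ConvolutionFolding changed module: " << changed;
    return tensorflow::Status::OK();
  }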