diff options
author | Justin Lebar <jlebar@google.com> | 2018-09-10 18:25:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 18:29:49 -0700 |
commit | 497715e0a9bbb3c844a1902e319778cc30819f77 (patch) | |
tree | 2250688253216474b6eb2dbbd6dd6c49b665573f /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | de683c50d039676e36b6a718e4cc7ed2170a8a2f (diff) |
[XLA:GPU] Don't canonicalize forward convs with constant filters to backwards conv.
No functional change.
PiperOrigin-RevId: 212373345
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 302 |
1 files changed, 167 insertions, 135 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 2a0823aeca..c88a3a3b4b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr<bool> FoldConvInputPad(HloInstruction* convolution); + StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. @@ -2212,169 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()))), - {})); - } - + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> { - if (lhs->opcode() != HloOpcode::kPad) { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(lhs->operand(1), 0)) { + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { return false; } - const auto& padding = lhs->padding_config(); - - // Can't pad batch or feature dims. - for (int64 dim : - {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); } + } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = window; - for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); - // Edge padding composes with itself in the straightforward way, but - // composing interior padding is nontrivial, and we cowardly refuse to - // think about it. If we see interior padding in either the kPad or conv, - // bail if there's any sort of padding in the other. - if (p.interior_padding() != 0 && - (w.padding_low() != 0 || w.padding_high() != 0 || - w.base_dilation() != 1)) { - return false; - } - if (w.base_dilation() != 1 && - (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0)) { - return false; - } + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - w.set_padding_low(w.padding_low() + p.edge_padding_low()); - w.set_padding_high(w.padding_high() + p.edge_padding_high()); - if (p.interior_padding() != 0) { - CHECK_EQ(w.base_dilation(), 1); - w.set_base_dilation(1 + p.interior_padding()); - } - } +StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs->mutable_operand(0), rhs}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } - if (folded_input_pad) { - return Status::OK(); + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; } - // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> { - if (rhs->opcode() != HloOpcode::kPad) { - return false; - } + const auto& padding = rhs->padding_config(); - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(rhs->operand(1), 0)) { + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - const auto& padding = rhs->padding_config(); + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - // Can't pad or dilate feature dims. - for (int64 dim : {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = convolution->window(); - for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - - // We can only do this transformation if p adds dilation to the filter -- - // edge padding on the filter is not supported in conv. - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { - return false; - } - - // Nothing to do if the kPad for this dim is entirely a nop. - if (p.interior_padding() == 0) { - continue; - } + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } - // We cowardly refuse to think about how dilation composes with itself; - // bail if both the kPad and conv have dilation on this dimension. - if (w.window_dilation() > 1) { - return false; - } - CHECK_EQ(w.window_dilation(), 1); - w.set_window_dilation(1 + p.interior_padding()); - w.set_size(rhs->operand(0)->shape().dimensions( - dnums.kernel_spatial_dimensions(dim))); + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs, rhs->mutable_operand(0)}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - if (folded_filter_pad) { - return Status::OK(); - } +StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. - // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } const Shape& input_shape = lhs->shape(); @@ -2387,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2398,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2424,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2466,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2478,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, convolution->precision_config())); - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. + if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( |