author Justin Lebar <jlebar@google.com> 2018-09-10 18:25:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-10 18:29:49 -0700
commit 497715e0a9bbb3c844a1902e319778cc30819f77 (patch)
tree 2250688253216474b6eb2dbbd6dd6c49b665573f
parent de683c50d039676e36b6a718e4cc7ed2170a8a2f (diff)
[XLA:GPU] Don't canonicalize forward convs with constant filters to backwards conv.
No functional change.
PiperOrigin-RevId: 212373345
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc  302
1 file changed, 167 insertions(+), 135 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 2a0823aeca..c88a3a3b4b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
return scalar_add_computation_;
}
+ // Tries to fold a kPad in the input or filter into the convolution
+ // instruction's window.
+ StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+ StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+ // Tries to use a kDot in place of the given convolution.
+ StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
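
The three declarations above split one monolithic visitor method into rewrites that each report whether they fired. A minimal sketch of the resulting dispatch pattern, with plain bool standing in for xla::StatusOr<bool> (which also carries an error) and hypothetical no-op functions standing in for the real rewrites:

#include <cstdio>

// Hypothetical stand-ins for the real rewrites; each returns true if it
// replaced the convolution it was given.
static bool FoldConvInputPad()  { return false; }
static bool FoldConvFilterPad() { return true;  }
static bool SimplifyConvToDot() { return false; }

int main() {
  // Mirrors the new HandleConvolution: try each rewrite in order and stop
  // at the first one that fired, since the visited instruction is gone.
  for (auto* rewrite : {FoldConvInputPad, FoldConvFilterPad, SimplifyConvToDot}) {
    if (rewrite()) {
      std::puts("a rewrite fired; the visitor returns early");
      return 0;
    }
  }
  std::puts("no rewrite applied");
  return 0;
}
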
@@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot strength reduction on platforms where it causes a slowdown.
bool enable_dot_strength_reduction_;
- // Disable convolution simplification on platforms where it causes a slowdown.
+ // Disable convolution -> dot simplification on platforms where it causes a
+ // slowdown.
bool enable_conv_simplification_;
// Cached computation for adding two scalar F32.
@@ -2212,169 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
return Status::OK();
}
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
HloInstruction* convolution) {
- auto lhs = convolution->mutable_operand(0);
- auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
- ShapeUtil::IsZeroElementArray(rhs->shape())) {
- return ReplaceWithNewInstruction(
- convolution,
- HloInstruction::CreateBroadcast(
- convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::Zero(convolution->shape().element_type()))),
- {}));
- }
-
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
const auto& window = convolution->window();
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
- // Try to merge padding/dilation of the input with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
- if (lhs->opcode() != HloOpcode::kPad) {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(lhs->operand(1), 0)) {
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
return false;
}
- const auto& padding = lhs->padding_config();
-
- // Can't pad batch or feature dims.
- for (int64 dim :
- {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
}
+ }
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = window;
- for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
- // Edge padding composes with itself in the straightforward way, but
- // composing interior padding is nontrivial, and we cowardly refuse to
- // think about it. If we see interior padding in either the kPad or conv,
- // bail if there's any sort of padding in the other.
- if (p.interior_padding() != 0 &&
- (w.padding_low() != 0 || w.padding_high() != 0 ||
- w.base_dilation() != 1)) {
- return false;
- }
- if (w.base_dilation() != 1 &&
- (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0)) {
- return false;
- }
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
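
For a single spatial dimension, the merge above is small arithmetic: the kPad's edge padding adds to the window's padding, and interior padding of k becomes base dilation 1 + k. A self-contained sketch of those rules, using simplified stand-in structs (hypothetical, not the real XLA protos):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for xla::WindowDimension and
// xla::PaddingConfig::PaddingConfigDimension.
struct WindowDim { int64_t padding_low = 0, padding_high = 0, base_dilation = 1; };
struct PadDim    { int64_t edge_low = 0, edge_high = 0, interior = 0; };

// Mirrors the per-dimension merge in FoldConvInputPad. Returns false for the
// combinations the pass refuses to handle: interior padding meeting existing
// window padding/dilation, or base dilation meeting any padding.
bool MergeInputPad(WindowDim& w, const PadDim& p) {
  if (p.interior != 0 &&
      (w.padding_low != 0 || w.padding_high != 0 || w.base_dilation != 1)) {
    return false;
  }
  if (w.base_dilation != 1 &&
      (p.edge_low != 0 || p.edge_high != 0 || p.interior != 0)) {
    return false;
  }
  w.padding_low += p.edge_low;
  w.padding_high += p.edge_high;
  if (p.interior != 0) w.base_dilation = 1 + p.interior;
  return true;
}

int main() {
  WindowDim w;        // conv window: no padding, base_dilation == 1
  PadDim p{2, 1, 3};  // kPad: edge low 2, edge high 1, interior 3
  assert(MergeInputPad(w, p));
  // The edge padding lands in the window; interior padding of 3 becomes base
  // dilation 4, i.e. three implicit zeros between consecutive input elements.
  std::printf("low=%lld high=%lld dilation=%lld\n",
              (long long)w.padding_low, (long long)w.padding_high,
              (long long)w.base_dilation);
  return 0;
}
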
- w.set_padding_low(w.padding_low() + p.edge_padding_low());
- w.set_padding_high(w.padding_high() + p.edge_padding_high());
- if (p.interior_padding() != 0) {
- CHECK_EQ(w.base_dilation(), 1);
- w.set_base_dilation(1 + p.interior_padding());
- }
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs->mutable_operand(0), rhs});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
- if (folded_input_pad) {
- return Status::OK();
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
}
- // Try to merge dilation of the filter with the convolution's window.
- TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
- if (rhs->opcode() != HloOpcode::kPad) {
- return false;
- }
+ const auto& padding = rhs->padding_config();
- // Convolution's padding is always zero, so bail if the kPad is adding
- // something other than zero.
- if (!IsAll(rhs->operand(1), 0)) {
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
return false;
}
+ }
- const auto& padding = rhs->padding_config();
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
- // Can't pad or dilate feature dims.
- for (int64 dim : {dnums.kernel_input_feature_dimension(),
- dnums.kernel_output_feature_dimension()}) {
- const auto& p = padding.dimensions(dim);
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
- p.interior_padding() != 0) {
- return false;
- }
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
}
- // Compute the window which is the result of merging the kPad and the
- // convolution's existing window.
- Window new_window = convolution->window();
- for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
- auto& w = *new_window.mutable_dimensions(dim);
- const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
-
- // We can only do this transformation if p adds dilation to the filter --
- // edge padding on the filter is not supported in conv.
- if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
- return false;
- }
-
- // Nothing to do if the kPad for this dim is entirely a nop.
- if (p.interior_padding() == 0) {
- continue;
- }
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
- // We cowardly refuse to think about how dilation composes with itself;
- // bail if both the kPad and conv have dilation on this dimension.
- if (w.window_dilation() > 1) {
- return false;
- }
- CHECK_EQ(w.window_dilation(), 1);
- w.set_window_dilation(1 + p.interior_padding());
- w.set_size(rhs->operand(0)->shape().dimensions(
- dnums.kernel_spatial_dimensions(dim)));
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
}
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
- auto new_conv = convolution->CloneWithNewOperands(
- convolution->shape(), {lhs, rhs->mutable_operand(0)});
- new_conv->set_window(new_window);
- TF_RETURN_IF_ERROR(
- ReplaceWithNewInstruction(convolution, std::move(new_conv)));
- return true;
- }());
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+}
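
The filter-side merge only absorbs interior padding: interior padding of k becomes window dilation 1 + k, and the window size reverts to the unpadded filter extent. Edge padding on the filter has no convolution equivalent, so the rewrite bails. A sketch with the same kind of hypothetical stand-in structs:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the window and padding fields this merge touches.
struct WindowDim { int64_t size = 0, window_dilation = 1; };
struct PadDim    { int64_t edge_low = 0, edge_high = 0, interior = 0; };

// Mirrors the per-dimension merge in FoldConvFilterPad: only filter dilation
// can be absorbed, and composing two dilations is refused.
bool MergeFilterPad(WindowDim& w, const PadDim& p, int64_t unpadded_size) {
  if (p.edge_low != 0 || p.edge_high != 0) return false;
  if (p.interior == 0) return true;   // kPad is a nop on this dim
  if (w.window_dilation > 1) return false;
  w.window_dilation = 1 + p.interior;
  w.size = unpadded_size;  // window size reverts to the unpadded filter extent
  return true;
}

int main() {
  // A filter of extent 3 with interior padding 1 has padded extent 5. After
  // folding, the window is the original size 3 with window dilation 2, which
  // spans the same 5 positions.
  WindowDim w{5, 1};
  PadDim p{0, 0, 1};
  assert(MergeFilterPad(w, p, /*unpadded_size=*/3));
  std::printf("size=%lld dilation=%lld\n",
              (long long)w.size, (long long)w.window_dilation);
  return 0;
}
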
- if (folded_filter_pad) {
- return Status::OK();
- }
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+ HloInstruction* convolution) {
+ auto* lhs = convolution->mutable_operand(0);
+ auto* rhs = convolution->mutable_operand(1);
+ const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
if (!enable_conv_simplification_) {
- return Status::OK();
+ return false;
}
- // HandleConvolution tries to replace a convolution with a DOT instruction.
- //
- // Only add when bitcasts can be used:
- // - if bitcasts are not supported, then reshapes could be used but will
- // end up with another copy.
- // - if bitcasts are supported, the simplifier will be called again with
- // bitcasts_ == true.
- // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+ // layout-insensitive mode, for fear of adding nontrivial reshapes.
if (!is_layout_sensitive_) {
- return Status::OK();
+ return false;
}
const Shape& input_shape = lhs->shape();
@@ -2387,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// Require the spatial dimensions in the kernel to have a bound of one.
for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
- return Status::OK();
+ return false;
}
}
@@ -2398,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
- return Status::OK();
+ return false;
}
// Also, the shapes must align for a rowmajor matmul:
@@ -2424,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dnums.kernel_input_feature_dimension()) <
PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
dnums.kernel_output_feature_dimension()))) {
- return Status::OK();
+ return false;
}
auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2466,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
!valid_bitcast_callback_(filter_shape, new_filter_shape) ||
!valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
- return Status::OK();
+ return false;
}
auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2478,7 +2473,44 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
convolution->precision_config()));
- return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+ return true;
+}
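
SimplifyConvToDot rests on the fact that a convolution whose kernel spatial dimensions all have extent 1, with no stride, padding, or base dilation, touches each (batch, spatial) position independently and so reduces to a matrix multiply. A standalone numerical check of that equivalence, on made-up sizes:

#include <cstdio>
#include <vector>

int main() {
  const int B = 2, S = 3, Cin = 4, Cout = 5;
  std::vector<float> in(B * S * Cin), filt(Cin * Cout);
  for (int i = 0; i < B * S * Cin; ++i) in[i] = 0.25f * i;
  for (int i = 0; i < Cin * Cout; ++i) filt[i] = 0.5f * i;

  // 1x1 "convolution": contract over in_ch at every (batch, spatial) point.
  std::vector<float> conv(B * S * Cout, 0.f);
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int co = 0; co < Cout; ++co)
        for (int ci = 0; ci < Cin; ++ci)
          conv[(b * S + s) * Cout + co] +=
              in[(b * S + s) * Cin + ci] * filt[ci * Cout + co];

  // "Dot": the same contraction with rows flattened to batch*spatial; in XLA
  // that reshape is a bitcast when the layouts line up.
  std::vector<float> dot(B * S * Cout, 0.f);
  for (int row = 0; row < B * S; ++row)
    for (int co = 0; co < Cout; ++co)
      for (int ci = 0; ci < Cin; ++ci)
        dot[row * Cout + co] += in[row * Cin + ci] * filt[ci * Cout + co];

  std::printf("conv == dot: %s\n", conv == dot ? "true" : "false");
  return 0;
}
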
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution) {
+ // Zero-sized input or filter.
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+ return ReplaceWithNewInstruction(
+ convolution,
+ HloInstruction::CreateBroadcast(
+ convolution->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type()))),
+ {}));
+ }
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
+ // Try to replace the convolution with a kDot instruction.
+ TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+ if (replaced_with_dot) {
+ return Status::OK();
+ }
+
+ return Status::OK();
}
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(