aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-08-30 20:38:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 20:46:31 -0700
commit423633fc4fb2b9c75f6013c3ded8eca8fe06843d (patch)
tree0195723e75fa0ef2937b7bc930ae0f5abe9571e9 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parent06ea8fb214b1b859b211ded0bbe31726214ee3f2 (diff)
[XLA] Merge kPad into kConvolution's window where possible.
This allows us to use e.g. cudnn's padding, instead of materializing a kPad instruction. PiperOrigin-RevId: 211028379
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc136
1 files changed, 134 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index a7a0044308..212ae97b59 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2213,7 +2213,141 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
.CloneToUnique())),
{}));
}
+
const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
+ return false;
+ }
+
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
+ }
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = rhs->padding_config();
+
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
+
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
+ }
+
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
+
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
+ }
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
if (!enable_conv_simplification_) {
return Status::OK();
}
@@ -2230,8 +2364,6 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
return Status::OK();
}
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
const Shape& input_shape = lhs->shape();
const Shape& filter_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();