diff options
author | Mark Heffernan <meheff@google.com> | 2017-09-14 16:25:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-14 16:33:08 -0700 |
commit | 57a4506da1fe74a41812a2b843c46b5fd010193d (patch) | |
tree | 142f96c5fa7371666ae2b7eda76c0b166e7427c8 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc | |
parent | c28534e9a6a8fe59f21bb34722d933d15290c731 (diff) |
Verify the output shapes of (almost) all HLO opcodes in the HloVerifier.
Previously, only the elementwise ones (approximately) were verified. As part
of this change fix the newly identified brokenness. The only remaining
unverified instruction is convolution which is being addressed in cl/166654245.
PiperOrigin-RevId: 168763722
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b8c6162084..9274e16a45 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -157,15 +157,24 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { Window new_conv_window = conv->window(); for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { WindowDimension* dim = new_conv_window.mutable_dimensions(i); + + // The size of the kernel may have changed so update the Window to match. + dim->set_size(new_kernel->shape().dimensions( + conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); dim->set_padding_low(0); dim->set_padding_high(0); dim->set_base_dilation(1); dim->set_window_dilation(1); } - TF_CHECK_OK(conv->parent()->ReplaceWithNewInstruction( - conv, HloInstruction::CreateConvolve( - conv->shape(), new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers()))); + + VLOG(1) << "Canonicalizing forward conv"; + auto new_conv = HloInstruction::CreateConvolve( + conv->shape(), new_input, new_kernel, new_conv_window, + conv->convolution_dimension_numbers()); + VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " + << new_conv->ToString(); + TF_CHECK_OK( + conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv))); return true; } @@ -274,6 +283,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( {new_transpose, new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); + + VLOG(1) << "Canonicalizing backward filter conv"; + VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " + << new_backward_conv->ToString(); + TF_CHECK_OK( computation->ReplaceInstruction(backward_conv, new_backward_conv)); return true; @@ -379,10 +393,17 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( limit_indices, strides) .ConsumeValueOrDie(), backward_conv->shape())); - TF_CHECK_OK(computation->ReplaceWithNewInstruction( - backward_conv, + + auto slice = HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices, strides))); + start_indices, limit_indices, strides); + + VLOG(1) << "Canonicalizing backward input conv"; + VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " + << slice->ToString(); + + TF_CHECK_OK( + computation->ReplaceWithNewInstruction(backward_conv, std::move(slice))); return true; } |