aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-09-14 16:25:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-14 16:33:08 -0700
commit57a4506da1fe74a41812a2b843c46b5fd010193d (patch)
tree142f96c5fa7371666ae2b7eda76c0b166e7427c8 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parentc28534e9a6a8fe59f21bb34722d933d15290c731 (diff)
Verify the output shapes of (almost) all HLO opcodes in the HloVerifier.
Previously, only the elementwise ones (approximately) were verified. As part of this change fix the newly identified brokenness. The only remaining unverified instruction is convolution which is being addressed in cl/166654245. PiperOrigin-RevId: 168763722
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc35
1 files changed, 28 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b8c6162084..9274e16a45 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -157,15 +157,24 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
Window new_conv_window = conv->window();
for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
WindowDimension* dim = new_conv_window.mutable_dimensions(i);
+
+ // The size of the kernel may have changed so update the Window to match.
+ dim->set_size(new_kernel->shape().dimensions(
+ conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
dim->set_padding_low(0);
dim->set_padding_high(0);
dim->set_base_dilation(1);
dim->set_window_dilation(1);
}
- TF_CHECK_OK(conv->parent()->ReplaceWithNewInstruction(
- conv, HloInstruction::CreateConvolve(
- conv->shape(), new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers())));
+
+ VLOG(1) << "Canonicalizing forward conv";
+ auto new_conv = HloInstruction::CreateConvolve(
+ conv->shape(), new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers());
+ VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
+ << new_conv->ToString();
+ TF_CHECK_OK(
+ conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv)));
return true;
}
@@ -274,6 +283,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
{new_transpose, new_forward_conv},
HloInstruction::FusionKind::kConvBackwardFilter,
new_backward_conv_window, backward_conv_dnums);
+
+ VLOG(1) << "Canonicalizing backward filter conv";
+ VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
+ << new_backward_conv->ToString();
+
TF_CHECK_OK(
computation->ReplaceInstruction(backward_conv, new_backward_conv));
return true;
@@ -379,10 +393,17 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
limit_indices, strides)
.ConsumeValueOrDie(),
backward_conv->shape()));
- TF_CHECK_OK(computation->ReplaceWithNewInstruction(
- backward_conv,
+
+ auto slice =
HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
- start_indices, limit_indices, strides)));
+ start_indices, limit_indices, strides);
+
+ VLOG(1) << "Canonicalizing backward input conv";
+ VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
+ << slice->ToString();
+
+ TF_CHECK_OK(
+ computation->ReplaceWithNewInstruction(backward_conv, std::move(slice)));
return true;
}