about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
author: David Majnemer <majnemer@google.com> 2017-11-30 16:07:24 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-30 16:11:11 -0800
commit 0438ac79bdb503ed267bec2146e7136ac8e99ff9 (patch)
tree 37d99d0ab03d7f044c493af47a447c3be3d8b576 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent 186caed810c0e9a9ee9a3f1e0f8bea50764ce5df (diff)
[TF:XLA] Use output spatial dimensions instead of a transpose for conv backwards filter.
PiperOrigin-RevId: 177522710
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 16
1 file changed, 3 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 11290eda4f..c29fee0879 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// ABCD0 = Pad(ABCD, padding_high=1)
// BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
// We choose the lesser of padding_low and padding_high as the new padding.
- HloInstruction* transpose = backward_conv->fused_expression_root();
- HloInstruction* forward_conv = transpose->mutable_operand(0);
+ HloInstruction* forward_conv = backward_conv->fused_expression_root();
HloInstruction* input = backward_conv->mutable_operand(0);
Window new_forward_conv_window = forward_conv->window();
Window new_backward_conv_window = backward_conv->window();
@@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
.ConsumeValueOrDie(),
padded_input, output, new_forward_conv_window, forward_conv_dnums));
- HloInstruction* new_transpose =
- computation->AddInstruction(HloInstruction::CreateTranspose(
- ShapeInference::InferTransposeShape(new_forward_conv->shape(),
- transpose->dimensions())
- .ConsumeValueOrDie(),
- new_forward_conv, transpose->dimensions()));
-
- // Fuse the new forward convolution and the new transpose to the new backward
- // convolution.
+ // Fuse the new forward convolution to the new backward convolution.
HloInstruction* new_backward_conv =
computation->CreateFusionInstructionForBackwardConvolution(
- {new_transpose, new_forward_conv},
- HloInstruction::FusionKind::kConvBackwardFilter,
+ {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter,
new_backward_conv_window, backward_conv_dnums);
VLOG(1) << "Canonicalizing backward filter conv";