diff options
author | David Majnemer <majnemer@google.com> | 2017-11-30 16:07:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-30 16:11:11 -0800 |
commit | 0438ac79bdb503ed267bec2146e7136ac8e99ff9 (patch) | |
tree | 37d99d0ab03d7f044c493af47a447c3be3d8b576 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc | |
parent | 186caed810c0e9a9ee9a3f1e0f8bea50764ce5df (diff) |
[TF:XLA] Use output spatial dimensions instead of a transpose for conv
backwards filter
PiperOrigin-RevId: 177522710
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 16 |
1 files changed, 3 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 11290eda4f..c29fee0879 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; |