author     Justin Lebar <jlebar@google.com>                  2018-02-01 21:34:18 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-02-02 10:28:02 -0800
commit     3be90a490c31d5a8fad70713e059bbb3e723e664 (patch)
tree       8ed60399e055c1cf917fede685ff1304230801c8 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent     d7fcf5a865570073569817fffafc07c8c74ec66d (diff)
Internal change
PiperOrigin-RevId: 184239740
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc  157
1 file changed, 74 insertions(+), 83 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2923a79af0..25846dc6cd 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -27,7 +27,7 @@ namespace gpu {
 namespace {
 
 bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
-  CHECK_EQ(HloOpcode::kConvolution, conv.opcode());
+  CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
   return window_util::HasSymmetricPadding(conv.window()) &&
          !window_util::HasNegativePadding(conv.window()) &&
          !window_util::HasDilation(conv.window());
@@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput(
       window_util::HasBaseDilation(conv_window)) {
     // If padding is uneven or has dilation, we insert a kPad instruction that
     // applies positive padding and dilation.
+    //
+    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
+    // moving all the padding into an explicit pad op, we should keep as much
+    // padding inside of cudnn as possible, on the assumption that padding
+    // within cudnn is basically free, whereas a kPad's cost increases as the
+    // amount of padding increases.
     PaddingConfig padding_config =
         MakeNoPaddingConfig(input->shape().dimensions_size());
     for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
@@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
     dim->set_window_dilation(1);
   }
 
+  // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
+  // out the shape of conv_result.
+  Shape old_conv_shape = conv->shape().tuple_shapes(0);
+
   VLOG(1) << "Canonicalizing forward conv";
-  auto new_conv = HloInstruction::CreateConvolve(
-      conv->shape(), new_input, new_kernel, new_conv_window,
-      conv->convolution_dimension_numbers());
+  auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
+                                         new_conv_window,
+                                         conv->convolution_dimension_numbers());
   VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
           << new_conv->ToString();
-  TF_CHECK_OK(
-      conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv)));
+  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
   return true;
 }
 
@@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
 
 bool PadInsertion::CanonicalizeBackwardFilterConvolution(
     HloInstruction* backward_conv) {
+  CHECK_EQ(backward_conv->custom_call_target(),
+           kCudnnConvBackwardFilterCallTarget);
   if (window_util::HasSymmetricPadding(backward_conv->window())) {
     return false;
   }
@@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
   // ABCD0 = Pad(ABCD, padding_high=1)
   // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1)
   // We choose the lesser of padding_low and padding_high as the new padding.
-  HloInstruction* forward_conv = backward_conv->fused_expression_root();
   HloInstruction* input = backward_conv->mutable_operand(0);
-  Window new_forward_conv_window = forward_conv->window();
   Window new_backward_conv_window = backward_conv->window();
   // input_padding_config is the config of the kPad to be inserted.
   PaddingConfig input_padding_config =
       MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
-  ConvolutionDimensionNumbers forward_conv_dnums =
-      forward_conv->convolution_dimension_numbers();
   ConvolutionDimensionNumbers backward_conv_dnums =
       backward_conv->convolution_dimension_numbers();
   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
@@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
       // cuDNN convolution (which doesn't support negative padding) to fail.
       return false;
     }
-    // If the backward convolution has uneven padding on the activations, we
-    // move some padding on the larger end to "internal" padding, so that the
-    // backward convolution produces larger weight gradients which get sliced
-    // later. Therefore, the amount of new padding (low or high) is the minimum
-    // of the amount of old padding low and old padding high.
+    // Compute the new, even padding for the backward conv operation.
     int64 new_conv_padding = std::min(padding_low, padding_high);
     int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
     input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
@@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
     // Since we move some padding from the backward convolution to the kPad, we
     // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
-    IncreasePaddingLowBy(-(padding_low - new_conv_padding),
-                         new_backward_conv_window.mutable_dimensions(i));
-    IncreasePaddingHighBy(-(padding_high - new_conv_padding),
-                          new_backward_conv_window.mutable_dimensions(i));
-    IncreasePaddingLowBy(-(padding_low - new_conv_padding),
-                         new_forward_conv_window.mutable_dimensions(i));
-    IncreasePaddingHighBy(-(padding_high - new_conv_padding),
-                          new_forward_conv_window.mutable_dimensions(i));
+    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
+    new_dim->set_padding_low(new_conv_padding);
+    new_dim->set_padding_high(new_conv_padding);
   }
 
   // Create a new backward convolution replacing the old one.
@@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
           .ConsumeValueOrDie(),
       input, padding, input_padding_config));
 
-  HloInstruction* new_forward_conv =
-      computation->AddInstruction(HloInstruction::CreateConvolve(
-          ShapeInference::InferConvolveShape(
-              padded_input->shape(), output->shape(), new_forward_conv_window,
-              forward_conv_dnums)
-              .ConsumeValueOrDie(),
-          padded_input, output, new_forward_conv_window, forward_conv_dnums));
-
-  // Fuse the new forward convolution to the new backward convolution.
-  HloInstruction* new_backward_conv =
-      computation->CreateFusionInstructionForBackwardConvolution(
-          {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter,
-          new_backward_conv_window, backward_conv_dnums);
+  // The shape of the backward_conv CustomCall is a tuple (conv_result,
+  // scratch_buffer). Extract out the shape of conv_result.
+  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+  HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
+      backward_conv_shape, padded_input, output, new_backward_conv_window,
+      backward_conv_dnums);
 
   VLOG(1) << "Canonicalizing backward filter conv";
   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
@@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
     return false;
   }
 
-  HloInstruction* forward_conv = backward_conv->fused_expression_root();
-  HloInstruction* reverse_filter = forward_conv->mutable_operand(1);
-  Window new_forward_conv_window = forward_conv->window();
   Window new_backward_conv_window = backward_conv->window();
-  ConvolutionDimensionNumbers forward_conv_dnums =
-      forward_conv->convolution_dimension_numbers();
   ConvolutionDimensionNumbers backward_conv_dnums =
       backward_conv->convolution_dimension_numbers();
+
+  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
+  // Get the shape of conv_result.
+  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+
+  Shape new_backward_conv_shape = backward_conv_shape;
   for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
     int64 padding_low = backward_conv->window().dimensions(i).padding_low();
     int64 padding_high = backward_conv->window().dimensions(i).padding_high();
@@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
     // where the amount of padding low is larger, we can canonicalize it to
     // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
     // [A] = Slice([B A])
-    // For consistency, we need to increase the low padding of the inner
-    // convolution by 1 as well because the input is larger now.
     if (padding_low > padding_high) {
       IncreasePaddingLowBy(padding_high - padding_low,
                            new_backward_conv_window.mutable_dimensions(i));
-      IncreasePaddingLowBy(padding_low - padding_high,
-                           new_forward_conv_window.mutable_dimensions(i));
     } else if (padding_low < padding_high) {
       IncreasePaddingHighBy(padding_low - padding_high,
                             new_backward_conv_window.mutable_dimensions(i));
-      IncreasePaddingHighBy(padding_high - padding_low,
-                            new_forward_conv_window.mutable_dimensions(i));
     }
+    // Decreasing the padding by X *increases* the size of our output by X.
+    int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
+    new_backward_conv_shape.set_dimensions(
+        dim, new_backward_conv_shape.dimensions(dim) +
+                 std::abs(padding_low - padding_high));
   }
 
   // Create a new backward convolution replacing the old one.
   HloComputation* computation = backward_conv->parent();
   HloInstruction* output = backward_conv->mutable_operand(0);
   HloInstruction* filter = backward_conv->mutable_operand(1);
-  HloInstruction* new_reverse_filter =
-      computation->AddInstruction(HloInstruction::CreateReverse(
-          filter->shape(), filter, reverse_filter->dimensions()));
-  HloInstruction* new_forward_conv =
-      computation->AddInstruction(HloInstruction::CreateConvolve(
-          ShapeInference::InferConvolveShape(
-              output->shape(), new_reverse_filter->shape(),
-              new_forward_conv_window, forward_conv_dnums)
-              .ConsumeValueOrDie(),
-          output, new_reverse_filter, new_forward_conv_window,
-          forward_conv_dnums));
+
+  HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
+      new_backward_conv_shape, output, filter, new_backward_conv_window,
+      backward_conv_dnums);
+
+  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
+  // Extract out the two elements.
   HloInstruction* new_backward_conv =
-      computation->CreateFusionInstructionForBackwardConvolution(
-          {new_forward_conv, new_reverse_filter},
-          HloInstruction::FusionKind::kConvBackwardInput,
-          new_backward_conv_window, backward_conv_dnums);
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_shape, new_backward_conv_call, 0));
+  HloInstruction* new_backward_conv_scratch =
+      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+          new_backward_conv_call->shape().tuple_shapes(1),
+          new_backward_conv_call, 1));
 
   // Slice the new backward convolution.
   //
@@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
   }
 
   // Replace the old backward convolution with the slice.
-  CHECK(ShapeUtil::Compatible(
+  Shape slice_shape =
       ShapeInference::InferSliceShape(new_backward_conv->shape(),
                                       start_indices, limit_indices, strides)
-          .ConsumeValueOrDie(),
-      backward_conv->shape()));
+          .ConsumeValueOrDie();
+  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
+      << ShapeUtil::HumanString(slice_shape) << " vs "
+      << ShapeUtil::HumanString(backward_conv_shape);
 
-  auto slice =
-      HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
-                                  start_indices, limit_indices, strides);
+  HloInstruction* slice = computation->AddInstruction(
+      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
+                                  start_indices, limit_indices, strides));
+  HloInstruction* new_tuple = computation->AddInstruction(
+      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
 
   VLOG(1) << "Canonicalizing backward input conv";
   VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
-          << slice->ToString();
+          << new_tuple->ToString();
 
-  TF_CHECK_OK(
-      computation->ReplaceWithNewInstruction(backward_conv, std::move(slice)));
+  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
   return true;
 }
 
@@ -400,18 +392,17 @@ StatusOr<bool> PadInsertion::Run(HloModule* module) {
   bool changed = false;
   for (HloInstruction* instruction :
        module->entry_computation()->MakeInstructionPostOrder()) {
-    if (instruction->opcode() == HloOpcode::kConvolution) {
-      changed |= CanonicalizeForwardConvolution(instruction);
-    } else if (instruction->opcode() == HloOpcode::kFusion) {
-      switch (instruction->fusion_kind()) {
-        case HloInstruction::FusionKind::kConvBackwardFilter:
-          changed |= CanonicalizeBackwardFilterConvolution(instruction);
-          break;
-        case HloInstruction::FusionKind::kConvBackwardInput:
-          changed |= CanonicalizeBackwardInputConvolution(instruction);
-          break;
-        default:
-          break;
+    if (IsCustomCallToDnnConvolution(*instruction)) {
+      const auto& target = instruction->custom_call_target();
+      if (target == kCudnnConvForwardCallTarget) {
+        changed |= CanonicalizeForwardConvolution(instruction);
+      } else if (target == kCudnnConvBackwardFilterCallTarget) {
+        changed |= CanonicalizeBackwardFilterConvolution(instruction);
+      } else if (target == kCudnnConvBackwardInputCallTarget) {
+        changed |= CanonicalizeBackwardInputConvolution(instruction);
+      } else {
+        LOG(FATAL) << "Unknown custom call target for cudnn conv: "
+                   << instruction->ToString();
       }
     }
  }
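
The backward-filter canonicalization above reduces to one rule: keep the lesser of the two padding amounts inside the conv (std::min(padding_low, padding_high)), and move the excess into an explicit kPad on the activations. Below is a minimal standalone sketch of that arithmetic for a single spatial dimension; DimPadding and every other name here are illustrative stand-ins, not XLA types.

// Sketch: canonicalizing uneven conv padding (low=3, high=1) into even
// conv padding plus an explicit kPad, as the backward-filter path does.
#include <algorithm>
#include <cstdint>
#include <iostream>

struct DimPadding {  // illustrative stand-in, not an XLA type
  int64_t low;
  int64_t high;
};

int main() {
  DimPadding conv_pad{3, 1};  // uneven padding cuDNN can't express directly

  // The conv keeps the lesser of the two amounts on both sides...
  int64_t new_conv_padding = std::min(conv_pad.low, conv_pad.high);

  // ...and the excess moves into the kPad's edge padding.
  DimPadding kpad{conv_pad.low - new_conv_padding,
                  conv_pad.high - new_conv_padding};

  std::cout << "conv: low=high=" << new_conv_padding  // 1
            << "  kPad: low=" << kpad.low             // 2
            << " high=" << kpad.high << "\n";         // 0
  return 0;
}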
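
The backward-input path works in the opposite direction: shrinking the larger padding side to match the smaller one grows the conv result by |padding_low - padding_high| in that dimension (the "Decreasing the padding by X *increases* the size of our output by X" comment), and a slice then recovers the originally requested region, per the "[B A] -> Slice -> [A]" example in the code comments. A standalone sketch of that arithmetic, again with purely illustrative names:

// Sketch: symmetrizing backward-input conv padding grows the result by
// |low - high|; a slice recovers the originally requested output.
#include <cstdint>
#include <cstdlib>
#include <iostream>

int main() {
  int64_t padding_low = 2, padding_high = 1;  // uneven padding
  int64_t old_output_size = 4;                // requested result size

  // The symmetrized conv produces |low - high| extra elements.
  int64_t new_output_size =
      old_output_size + std::abs(padding_low - padding_high);

  // The extra elements sit at the low end when padding_low was larger
  // (the "[B A]" in the example above), so the slice starts past them.
  int64_t slice_start =
      padding_low > padding_high ? padding_low - padding_high : 0;
  int64_t slice_limit = slice_start + old_output_size;

  std::cout << "conv result size: " << new_output_size << "\n"   // 5
            << "slice: [" << slice_start << ", " << slice_limit  // [1, 5)
            << ")\n";
  return 0;
}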
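
Finally, the rewritten PadInsertion::Run dispatches on the CustomCall target string rather than on opcode or fusion kind. A sketch of that control flow follows; the target values here are placeholders (the patch references real kCudnnConv*CallTarget constants defined elsewhere in XLA's GPU backend):

// Sketch: dispatching canonicalization by custom-call target string.
#include <iostream>
#include <string>

// Placeholder values, standing in for XLA's kCudnnConv*CallTarget constants.
const std::string kCudnnConvForwardCallTarget = "conv-forward";
const std::string kCudnnConvBackwardInputCallTarget = "conv-backward-input";
const std::string kCudnnConvBackwardFilterCallTarget = "conv-backward-filter";

bool CanonicalizeByTarget(const std::string& target) {
  if (target == kCudnnConvForwardCallTarget) {
    std::cout << "forward canonicalization\n";
  } else if (target == kCudnnConvBackwardFilterCallTarget) {
    std::cout << "backward-filter canonicalization\n";
  } else if (target == kCudnnConvBackwardInputCallTarget) {
    std::cout << "backward-input canonicalization\n";
  } else {
    return false;  // the real pass LOG(FATAL)s on unknown targets
  }
  return true;
}

int main() {
  CanonicalizeByTarget(kCudnnConvForwardCallTarget);
  CanonicalizeByTarget(kCudnnConvBackwardInputCallTarget);
  return 0;
}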