path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
author    Justin Lebar <jlebar@google.com>  2018-02-01 21:34:18 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-02 10:28:02 -0800
commit    3be90a490c31d5a8fad70713e059bbb3e723e664 (patch)
tree      8ed60399e055c1cf917fede685ff1304230801c8 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent    d7fcf5a865570073569817fffafc07c8c74ec66d (diff)
Internal change
PiperOrigin-RevId: 184239740
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 157
1 file changed, 74 insertions(+), 83 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2923a79af0..25846dc6cd 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -27,7 +27,7 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(HloOpcode::kConvolution, conv.opcode());
+ CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -47,6 +47,12 @@ HloInstruction* MaybePaddedAndSlicedInput(
window_util::HasBaseDilation(conv_window)) {
// If padding is uneven or has dilation, we insert a kPad instruction that
// applies positive padding and dilation.
+ //
+ // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
+ // moving all the padding into an explicit pad op, we should keep as much
+ // padding inside of cudnn as possible, on the assumption that padding
+ // within cudnn is basically free, whereas a kPad's cost increases as the
+ // amount of padding increases.
PaddingConfig padding_config =
MakeNoPaddingConfig(input->shape().dimensions_size());
for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
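// Illustrative sketch (not part of this patch), assuming the usual xla::Window
// and xla::PaddingConfig accessors: how uneven edge padding and base dilation
// on the convolution window can be folded into the explicit kPad's config.
// The helper name MovePaddingIntoPadConfig is hypothetical.
PaddingConfig MovePaddingIntoPadConfig(const Window& conv_window,
                                       const ConvolutionDimensionNumbers& dnums,
                                       int64 input_rank) {
  PaddingConfig config = MakeNoPaddingConfig(input_rank);
  for (size_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
    int64 dim = dnums.input_spatial_dimensions(i);
    auto* pad_dim = config.mutable_dimensions(dim);
    // All of the window's edge padding becomes explicit edge padding on the
    // kPad, and base dilation D becomes D-1 units of interior padding.
    pad_dim->set_edge_padding_low(conv_window.dimensions(i).padding_low());
    pad_dim->set_edge_padding_high(conv_window.dimensions(i).padding_high());
    pad_dim->set_interior_padding(conv_window.dimensions(i).base_dilation() - 1);
  }
  return config;
}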
@@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
dim->set_window_dilation(1);
}
+ // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
+ // out the shape of conv_result.
+ Shape old_conv_shape = conv->shape().tuple_shapes(0);
+
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = HloInstruction::CreateConvolve(
- conv->shape(), new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers());
+ auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
+ new_conv_window,
+ conv->convolution_dimension_numbers());
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
- TF_CHECK_OK(
- conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv)));
+ TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
return true;
}
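// Illustrative sketch (not part of this patch): the cudnn conv CustomCalls
// created above have a tuple shape of roughly (conv_result, scratch_buffer),
// e.g. (f32[...], u8[scratch_bytes]).  Callers that need the convolution
// result itself either read tuple_shapes(0), as done here, or peel it off with
// a get-tuple-element.  The helper name GetConvResult is hypothetical.
HloInstruction* GetConvResult(HloComputation* computation,
                              HloInstruction* conv_call) {
  const Shape& result_shape = conv_call->shape().tuple_shapes(0);
  return computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(result_shape, conv_call, 0));
}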
@@ -190,6 +199,8 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
HloInstruction* backward_conv) {
+ CHECK_EQ(backward_conv->custom_call_target(),
+ kCudnnConvBackwardFilterCallTarget);
if (window_util::HasSymmetricPadding(backward_conv->window())) {
return false;
}
@@ -202,15 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// ABCD0 = Pad(ABCD, padding_high=1)
// BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
// We choose the lesser of padding_low and padding_high as the new padding.
- HloInstruction* forward_conv = backward_conv->fused_expression_root();
HloInstruction* input = backward_conv->mutable_operand(0);
- Window new_forward_conv_window = forward_conv->window();
Window new_backward_conv_window = backward_conv->window();
// input_padding_config is the config of the kPad to be inserted.
PaddingConfig input_padding_config =
MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
- ConvolutionDimensionNumbers forward_conv_dnums =
- forward_conv->convolution_dimension_numbers();
ConvolutionDimensionNumbers backward_conv_dnums =
backward_conv->convolution_dimension_numbers();
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
@@ -222,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// cuDNN convolution (which doesn't support negative padding) to fail.
return false;
}
- // If the backward convolution has uneven padding on the activations, we
- // move some padding on the larger end to "internal" padding, so that the
- // backward convolution produces larger weight gradients which get sliced
- // later. Therefore, the amount of new padding (low or high) is the minimum
- // of the amount of old padding low and old padding high.
+ // Compute the new, even padding for the backward conv operation.
int64 new_conv_padding = std::min(padding_low, padding_high);
int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
@@ -237,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Since we move some padding from the backward convolution to the kPad, we
// need to accordingly reduce the padding amount of the backward convolution.
- IncreasePaddingLowBy(-(padding_low - new_conv_padding),
- new_backward_conv_window.mutable_dimensions(i));
- IncreasePaddingHighBy(-(padding_high - new_conv_padding),
- new_backward_conv_window.mutable_dimensions(i));
- IncreasePaddingLowBy(-(padding_low - new_conv_padding),
- new_forward_conv_window.mutable_dimensions(i));
- IncreasePaddingHighBy(-(padding_high - new_conv_padding),
- new_forward_conv_window.mutable_dimensions(i));
+ auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
+ new_dim->set_padding_low(new_conv_padding);
+ new_dim->set_padding_high(new_conv_padding);
}
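// Illustrative sketch (not part of this patch): the padding split computed in
// the loop above, written out as a standalone helper.  The struct and function
// names are hypothetical; the arithmetic mirrors what the loop does.
struct PaddingSplit {
  int64 conv_padding;  // symmetric padding kept on the cudnn backward conv
  int64 pad_low;       // extra low padding moved into the explicit kPad
  int64 pad_high;      // extra high padding moved into the explicit kPad
};
PaddingSplit SplitUnevenPadding(int64 padding_low, int64 padding_high) {
  PaddingSplit split;
  split.conv_padding = std::min(padding_low, padding_high);
  split.pad_low = padding_low - split.conv_padding;    // 0 if high >= low
  split.pad_high = padding_high - split.conv_padding;  // 0 if low >= high
  return split;
}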
// Create a new backward convolution replacing the old one.
@@ -260,19 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
.ConsumeValueOrDie(),
input, padding, input_padding_config));
- HloInstruction* new_forward_conv =
- computation->AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- padded_input->shape(), output->shape(), new_forward_conv_window,
- forward_conv_dnums)
- .ConsumeValueOrDie(),
- padded_input, output, new_forward_conv_window, forward_conv_dnums));
-
- // Fuse the new forward convolution to the new backward convolution.
- HloInstruction* new_backward_conv =
- computation->CreateFusionInstructionForBackwardConvolution(
- {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter,
- new_backward_conv_window, backward_conv_dnums);
+ // The shape of the backward_conv CustomCall is a tuple (conv_result,
+ // scratch_buffer). Extract out the shape of conv_result.
+ Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+ HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
+ backward_conv_shape, padded_input, output, new_backward_conv_window,
+ backward_conv_dnums);
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -289,14 +280,15 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
return false;
}
- HloInstruction* forward_conv = backward_conv->fused_expression_root();
- HloInstruction* reverse_filter = forward_conv->mutable_operand(1);
- Window new_forward_conv_window = forward_conv->window();
Window new_backward_conv_window = backward_conv->window();
- ConvolutionDimensionNumbers forward_conv_dnums =
- forward_conv->convolution_dimension_numbers();
ConvolutionDimensionNumbers backward_conv_dnums =
backward_conv->convolution_dimension_numbers();
+
+ // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
+ // Get the shape of conv_result.
+ Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
+
+ Shape new_backward_conv_shape = backward_conv_shape;
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
int64 padding_low = backward_conv->window().dimensions(i).padding_low();
int64 padding_high = backward_conv->window().dimensions(i).padding_high();
@@ -315,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
// where the amount of padding low is larger, we can canonicalize it to
// [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
// [A] = Slice([B A])
- // For consistency, we need to increase the low padding of the inner
- // convolution by 1 as well because the input is larger now.
if (padding_low > padding_high) {
IncreasePaddingLowBy(padding_high - padding_low,
new_backward_conv_window.mutable_dimensions(i));
- IncreasePaddingLowBy(padding_low - padding_high,
- new_forward_conv_window.mutable_dimensions(i));
} else if (padding_low < padding_high) {
IncreasePaddingHighBy(padding_low - padding_high,
new_backward_conv_window.mutable_dimensions(i));
- IncreasePaddingHighBy(padding_high - padding_low,
- new_forward_conv_window.mutable_dimensions(i));
}
+ // Decreasing the padding by X *increases* the size of our output by X.
+ int64 dim = backward_conv_dnums.output_spatial_dimensions(i);
+ new_backward_conv_shape.set_dimensions(
+ dim, new_backward_conv_shape.dimensions(dim) +
+ std::abs(padding_low - padding_high));
}
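// Worked example (not part of this patch): if the original backward-input conv
// had padding (low=3, high=1) on some spatial dimension and produced N
// elements there, the canonicalized conv uses the even padding (low=1, high=1)
// and therefore produces N + |3 - 1| = N + 2 elements; the two extra elements
// are removed again by the slice created further below.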
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(0);
HloInstruction* filter = backward_conv->mutable_operand(1);
- HloInstruction* new_reverse_filter =
- computation->AddInstruction(HloInstruction::CreateReverse(
- filter->shape(), filter, reverse_filter->dimensions()));
- HloInstruction* new_forward_conv =
- computation->AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(
- output->shape(), new_reverse_filter->shape(),
- new_forward_conv_window, forward_conv_dnums)
- .ConsumeValueOrDie(),
- output, new_reverse_filter, new_forward_conv_window,
- forward_conv_dnums));
+
+ HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
+ new_backward_conv_shape, output, filter, new_backward_conv_window,
+ backward_conv_dnums);
+
+ // The CustomCall created above returns a tuple (conv_result, scratch_memory).
+ // Extract out the two elements.
HloInstruction* new_backward_conv =
- computation->CreateFusionInstructionForBackwardConvolution(
- {new_forward_conv, new_reverse_filter},
- HloInstruction::FusionKind::kConvBackwardInput,
- new_backward_conv_window, backward_conv_dnums);
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_backward_conv_shape, new_backward_conv_call, 0));
+ HloInstruction* new_backward_conv_scratch =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_backward_conv_call->shape().tuple_shapes(1),
+ new_backward_conv_call, 1));
// Slice the new backward convolution.
//
@@ -377,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
}
// Replace the old backward convolution with the slice.
- CHECK(ShapeUtil::Compatible(
+ Shape slice_shape =
ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
limit_indices, strides)
- .ConsumeValueOrDie(),
- backward_conv->shape()));
+ .ConsumeValueOrDie();
+ CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
+ << ShapeUtil::HumanString(slice_shape) << " vs "
+ << ShapeUtil::HumanString(backward_conv_shape);
- auto slice =
- HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
- start_indices, limit_indices, strides);
+ HloInstruction* slice = computation->AddInstruction(
+ HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
+ start_indices, limit_indices, strides));
+ HloInstruction* new_tuple = computation->AddInstruction(
+ HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));
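// Note (not part of this patch): the replacement has to keep the original
// CustomCall's (conv_result, scratch_buffer) tuple shape, which is why the
// sliced result is re-packed with the scratch buffer into new_tuple before
// ReplaceInstruction swaps it in for backward_conv.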
VLOG(1) << "Canonicalizing backward input conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
- << slice->ToString();
+ << new_tuple->ToString();
- TF_CHECK_OK(
- computation->ReplaceWithNewInstruction(backward_conv, std::move(slice)));
+ TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
return true;
}
@@ -400,18 +392,17 @@ StatusOr<bool> PadInsertion::Run(HloModule* module) {
bool changed = false;
for (HloInstruction* instruction :
module->entry_computation()->MakeInstructionPostOrder()) {
- if (instruction->opcode() == HloOpcode::kConvolution) {
- changed |= CanonicalizeForwardConvolution(instruction);
- } else if (instruction->opcode() == HloOpcode::kFusion) {
- switch (instruction->fusion_kind()) {
- case HloInstruction::FusionKind::kConvBackwardFilter:
- changed |= CanonicalizeBackwardFilterConvolution(instruction);
- break;
- case HloInstruction::FusionKind::kConvBackwardInput:
- changed |= CanonicalizeBackwardInputConvolution(instruction);
- break;
- default:
- break;
+ if (IsCustomCallToDnnConvolution(*instruction)) {
+ const auto& target = instruction->custom_call_target();
+ if (target == kCudnnConvForwardCallTarget) {
+ changed |= CanonicalizeForwardConvolution(instruction);
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ changed |= CanonicalizeBackwardFilterConvolution(instruction);
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ changed |= CanonicalizeBackwardInputConvolution(instruction);
+ } else {
+ LOG(FATAL) << "Unknown custom call target for cudnn conv: "
+ << instruction->ToString();
}
}
}
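// Illustrative sketch (not part of this patch): with this change the pass keys
// off cudnn custom-call targets instead of fusion kinds, so a minimal use of
// it (assuming the usual HloPassPipeline API) looks roughly like:
//
//   HloPassPipeline pipeline("pad-insertion");
//   pipeline.AddPass<PadInsertion>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));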