aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
blob: b42a19e3a2200e917f8040be183b8d79c9e4e161 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// Makes explicit the input-operand padding of a convolution that can't be
// folded into a cuDNN convolution libcall.
//
//  * When `conv_window` has asymmetric padding or base dilation, a kPad is
//    built that applies the non-negative part of the low/high padding of every
//    spatial dimension, plus the base dilation as interior padding.
//  * When `conv_window` has negative padding, a kSlice is built that trims the
//    negative low/high padding off every spatial dimension.
//
// Both may apply, in which case the slice consumes the pad. Returns the last
// instruction built, or `input` unchanged when neither case applies.
HloInstruction* MaybePaddedAndSlicedInput(
    const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  const size_t num_spatial_dims = conv_dnums.input_spatial_dimensions().size();

  const bool needs_explicit_pad =
      !window_util::HasSymmetricPadding(conv_window) ||
      window_util::HasBaseDilation(conv_window);
  if (needs_explicit_pad) {
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig pad_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < num_spatial_dims; ++i) {
      const WindowDimension& window_dim = conv_window.dimensions(i);
      auto* pad_dim =
          pad_config.mutable_dimensions(conv_dnums.input_spatial_dimensions(i));
      // Only the non-negative part of the edge padding goes into the kPad;
      // negative padding is handled by the slice below.
      pad_dim->set_edge_padding_low(
          std::max<int64>(0LL, window_dim.padding_low()));
      pad_dim->set_edge_padding_high(
          std::max<int64>(0LL, window_dim.padding_high()));
      // A base dilation of d puts d - 1 padding elements between neighbours.
      pad_dim->set_interior_padding(window_dim.base_dilation() - 1);
    }

    // The pad value is a zero of the input's element type.
    HloComputation* computation = input->parent();
    HloInstruction* zero = computation->AddInstruction(
        HloInstruction::CreateConstant(
            LiteralUtil::Zero(input->shape().element_type())));
    input = MakePadHlo(input, zero, pad_config).ValueOrDie();
  }

  if (window_util::HasNegativePadding(conv_window)) {
    // Start from a slice that keeps everything: [0, dimension size) with unit
    // stride along every dimension of the (possibly padded) input.
    const auto& input_dims = input->shape().dimensions();
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);

    for (size_t i = 0; i < num_spatial_dims; ++i) {
      const WindowDimension& window_dim = conv_window.dimensions(i);
      const int64 dim = conv_dnums.input_spatial_dimensions(i);
      // Negative low padding moves the start index up by its magnitude;
      // negative high padding moves the limit index down by its magnitude.
      if (window_dim.padding_low() < 0) {
        start_indices[dim] -= window_dim.padding_low();
      }
      if (window_dim.padding_high() < 0) {
        limit_indices[dim] += window_dim.padding_high();
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// Makes the window dilation of a convolution explicit on its kernel operand.
// Window dilation can't be folded into a cuDNN convolution libcall, so when
// `conv_window` has any, this returns a kPad of `kernel` that inserts
// (window_dilation - 1) zeros between neighbouring elements of each kernel
// spatial dimension. Without window dilation, `kernel` is returned unchanged.
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  const bool has_window_dilation = window_util::HasWindowDilation(conv_window);
  if (!has_window_dilation) {
    return kernel;
  }

  // One default-constructed padding entry per kernel dimension; only the
  // spatial dimensions get interior padding below.
  PaddingConfig dilation_config;
  const size_t kernel_rank = kernel->shape().dimensions_size();
  for (size_t d = 0; d < kernel_rank; ++d) {
    dilation_config.add_dimensions();
  }

  const size_t num_spatial_dims = conv_dnums.kernel_spatial_dimensions().size();
  for (size_t i = 0; i < num_spatial_dims; ++i) {
    const int64 kernel_dim = conv_dnums.kernel_spatial_dimensions(i);
    dilation_config.mutable_dimensions(kernel_dim)
        ->set_interior_padding(conv_window.dimensions(i).window_dilation() - 1);
  }

  // The pad value is a zero of the kernel's element type.
  HloInstruction* zero = kernel->parent()->AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::Zero(kernel->shape().element_type())));
  return MakePadHlo(kernel, zero, dilation_config).ValueOrDie();
}
}  // namespace

// Rewrites a non-canonical forward convolution custom call so that its window
// is canonical (symmetric, non-negative padding and no dilation). Padding and
// dilation that the window can no longer express are applied through explicit
// kPad/kSlice instructions on the input and kernel operands.
//
// Returns false, leaving `conv` untouched, if it is already canonical;
// otherwise replaces `conv` in its computation and returns true.
bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      conv->window(), conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // MaybePaddedAndSlicedInput() only turns positive padding into a kPad when
  // the window has asymmetric padding or base dilation. In the remaining
  // non-canonical cases (symmetric padding together with window dilation, or
  // with negative padding on another dimension) no kPad is inserted, so the
  // positive padding must stay on the convolution's window; clearing it would
  // silently drop that padding. Keep this condition in sync with
  // MaybePaddedAndSlicedInput().
  const bool input_pad_inserted =
      !window_util::HasSymmetricPadding(conv->window()) ||
      window_util::HasBaseDilation(conv->window());

  // Strip from the window everything that was made explicit by the inserted
  // pads and slices.
  Window new_conv_window = conv->window();
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    if (input_pad_inserted) {
      // All non-negative padding lives in the kPad and all negative padding in
      // the kSlice.
      dim->set_padding_low(0);
      dim->set_padding_high(0);
    } else {
      // Only negative padding was made explicit (by the kSlice). The padding
      // is symmetric here, so clamping at zero keeps it symmetric and
      // non-negative.
      dim->set_padding_low(std::max<int64>(0LL, dim->padding_low()));
      dim->set_padding_high(std::max<int64>(0LL, dim->padding_high()));
    }
    // Base dilation is applied by the input kPad (and is already 1 when no
    // kPad was inserted); window dilation is applied by the kernel kPad.
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);
  }

  // Clone the custom call with the new input/kernel. Any further operands and
  // the result shape are carried over unchanged.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
// Adds `delta` to the low padding of `window_dim`. A negative `delta` reduces
// the padding.
void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
  const auto updated_padding = window_dim->padding_low() + delta;
  window_dim->set_padding_low(updated_padding);
}

// Adds `delta` to the high padding of `window_dim`. A negative `delta` reduces
// the padding.
void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
  const auto updated_padding = window_dim->padding_high() + delta;
  window_dim->set_padding_high(updated_padding);
}
}  // namespace

// Canonicalizes a backward-filter convolution custom call whose window has
// asymmetric padding by moving the uneven part of the padding into an explicit
// kPad on operand 0.
//
// Returns false without modifying anything if the padding is already symmetric
// or if any window dimension has negative padding; otherwise replaces
// `backward_conv` in its computation and returns true. CHECK-fails if
// `backward_conv` is not a backward-filter convolution custom call.
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding is equivalent to one
  // with even padding whose activations (input) were padded beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // The new, even padding is the lesser of padding_low and padding_high.
  const ConvolutionDimensionNumbers dnums =
      backward_conv->convolution_dimension_numbers();
  HloInstruction* activations = backward_conv->mutable_operand(0);
  // Configuration of the kPad applied to the activations.
  PaddingConfig activations_pad_config =
      MakeNoPaddingConfig(ShapeUtil::Rank(activations->shape()));
  // Window of the replacement convolution.
  Window even_window = backward_conv->window();

  const auto num_window_dims = backward_conv->window().dimensions_size();
  for (int i = 0; i < num_window_dims; ++i) {
    const int64 low = backward_conv->window().dimensions(i).padding_low();
    const int64 high = backward_conv->window().dimensions(i).padding_high();
    if (low < 0 || high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }

    // The part of the padding both edges have in common stays on the
    // convolution; whatever exceeds it on either edge goes into the kPad.
    const int64 even_padding = std::min(low, high);
    auto* pad_dim = activations_pad_config.mutable_dimensions(
        dnums.input_spatial_dimensions(i));
    pad_dim->set_edge_padding_low(low - even_padding);
    pad_dim->set_edge_padding_high(high - even_padding);

    WindowDimension* window_dim = even_window.mutable_dimensions(i);
    window_dim->set_padding_low(even_padding);
    window_dim->set_padding_high(even_padding);
  }

  // Build the zero-padded activations.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* zero =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(activations->shape().element_type())));
  HloInstruction* padded_activations =
      MakePadHlo(activations, zero, activations_pad_config).ValueOrDie();

  // Clone the custom call onto the padded activations. The result shape of the
  // original custom call is reused as-is; only the window changes.
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_activations, output}));
  new_backward_conv->set_window(even_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

// Canonicalizes a backward-input convolution custom call whose window has
// asymmetric padding. The larger padding of each window dimension is lowered
// to match the smaller one, which makes the convolution produce a larger
// result; that result is then sliced back to the original result shape.
//
// Returns false without modifying anything if the padding is already symmetric
// or if any window dimension has negative padding; otherwise replaces
// `backward_conv` with a tuple (sliced result, scratch) and returns true.
bool PadInsertion::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  const ConvolutionDimensionNumbers dnums =
      backward_conv->convolution_dimension_numbers();
  const auto num_window_dims = backward_conv->window().dimensions_size();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Element 0 is the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  // Phase 1: derive the symmetric window and the enlarged result shape of the
  // replacement convolution.
  Window even_window = backward_conv->window();
  Shape enlarged_shape = backward_conv_shape;
  for (int i = 0; i < num_window_dims; ++i) {
    const int64 low = backward_conv->window().dimensions(i).padding_low();
    const int64 high = backward_conv->window().dimensions(i).padding_high();
    if (low < 0 || high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }

    // With uneven padding on the activations, the surplus padding on the
    // larger end becomes "internal" padding: the backward convolution yields
    // larger activations that are sliced afterwards.
    //
    // For example, the non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // has more low padding than high padding and is canonicalized to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
    WindowDimension* window_dim = even_window.mutable_dimensions(i);
    if (low > high) {
      IncreasePaddingLowBy(high - low, window_dim);
    } else if (low < high) {
      IncreasePaddingHighBy(low - high, window_dim);
    }

    // Decreasing the padding by X *increases* the size of our output by X.
    const int64 result_dim = dnums.output_spatial_dimensions(i);
    const int64 enlarged_size =
        enlarged_shape.dimensions(result_dim) + std::abs(low - high);
    enlarged_shape.set_dimensions(result_dim, enlarged_size);
  }

  // Phase 2: emit the replacement custom call on the same operands, returning
  // the enlarged result together with a zero-element U8 scratch shape.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* enlarged_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {enlarged_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  enlarged_call->set_window(even_window);

  // Unpack the (conv_result, scratch_memory) tuple of the new custom call.
  HloInstruction* enlarged_result =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          enlarged_shape, enlarged_call, 0));
  HloInstruction* scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          enlarged_call->shape().tuple_shapes(1), enlarged_call, 1));

  // Phase 3: slice the enlarged result back down.
  //
  // Begin with bounds that keep everything, then trim each spatial dimension
  // on the side whose padding was reduced.
  const auto& enlarged_dims = enlarged_result->shape().dimensions();
  std::vector<int64> start_indices(enlarged_result->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(enlarged_dims.begin(), enlarged_dims.end());
  std::vector<int64> strides(enlarged_result->shape().dimensions_size(), 1LL);
  for (int i = 0; i < num_window_dims; ++i) {
    const int64 low = backward_conv->window().dimensions(i).padding_low();
    const int64 high = backward_conv->window().dimensions(i).padding_high();
    const int64 result_dim = dnums.output_spatial_dimensions(i);
    // excess > 0: the old low padding was larger, so the internal padding sits
    // at the low end and the slice starts after it.
    // excess < 0: the old high padding was larger, so the internal padding
    // sits at the high end and the slice stops before it.
    const int64 excess = low - high;
    if (excess > 0) {
      start_indices[result_dim] += excess;
    } else if (excess < 0) {
      limit_indices[result_dim] += excess;
    }
  }

  // The slice must reproduce the shape of the original convolution result.
  Shape slice_shape =
      ShapeInference::InferSliceShape(enlarged_result->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  // Replace the old backward convolution with a tuple of the sliced result and
  // the new scratch element.
  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, enlarged_result,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

// Canonicalizes every cuDNN convolution custom call in `computation`,
// dispatching on the custom call target. Returns true if at least one
// convolution was rewritten. Aborts via LOG(FATAL) on a convolution custom
// call with an unrecognized target.
StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
  // Collect the convolutions up front: the canonicalization routines add and
  // replace instructions in the computation being walked.
  std::vector<HloInstruction*> dnn_convs;
  for (auto* candidate : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*candidate)) {
      dnn_convs.push_back(candidate);
    }
  }

  bool any_rewritten = false;
  for (HloInstruction* conv : dnn_convs) {
    const auto& target = conv->custom_call_target();
    bool rewritten = false;
    if (target == kCudnnConvForwardCallTarget ||
        target == kCudnnConvBiasActivationForwardCallTarget) {
      rewritten = CanonicalizeForwardConvolution(conv);
    } else if (target == kCudnnConvBackwardFilterCallTarget) {
      rewritten = CanonicalizeBackwardFilterConvolution(conv);
    } else if (target == kCudnnConvBackwardInputCallTarget) {
      rewritten = CanonicalizeBackwardInputConvolution(conv);
    } else {
      LOG(FATAL) << "Unknown custom call target for cudnn conv: "
                 << conv->ToString();
    }
    any_rewritten = any_rewritten || rewritten;
  }
  return any_rewritten;
}

// Runs the pass over every non-fusion computation of `module`. Returns true if
// any computation was changed, or the first error reported by
// RunOnComputation().
StatusOr<bool> PadInsertion::Run(HloModule* module) {
  bool module_changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool computation_changed,
                        RunOnComputation(computation));
    if (computation_changed) {
      module_changed = true;
    }
  }
  return module_changed;
}

}  // namespace gpu
}  // namespace xla