aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-03-08 11:49:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 11:53:40 -0800
commit214ad0978641a946c25b334c4a33ecd1793b4d70 (patch)
tree0e35b58a96b2b8f73822a1ccfb39f93d5e64d806 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent52ed0eed35d782fbf13fbfbfd6a1e755c56a5f80 (diff)
Add some simple HLO creation utilities to auto-infer result shapes
I need something like this for my Gather HLO->HLO lowering pass. PiperOrigin-RevId: 188365102
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc30
1 files changed, 6 insertions, 24 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 25846dc6cd..fa405b9329 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -68,13 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(element_type))));
- input = computation->AddInstruction(HloInstruction::CreatePad(
- ShapeInference::InferPadShape(
- /*operand_shape=*/input->shape(),
- /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
- padding_config)
- .ConsumeValueOrDie(),
- input, padding, padding_config));
+ input = CreatePadHlo(input, padding, padding_config).ValueOrDie();
}
if (window_util::HasNegativePadding(conv_window)) {
@@ -97,11 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput(
std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
}
- input = computation->AddInstruction(HloInstruction::CreateSlice(
- ShapeInference::InferSliceShape(input->shape(), start_indices,
- limit_indices, strides)
- .ConsumeValueOrDie(),
- input, start_indices, limit_indices, strides));
+ input = CreateSliceHlo(input, start_indices, limit_indices, strides)
+ .ValueOrDie();
}
return input;
@@ -134,13 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(element_type))));
- return computation->AddInstruction(HloInstruction::CreatePad(
- ShapeInference::InferPadShape(
- /*operand_shape=*/kernel->shape(),
- /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
- padding_config)
- .ConsumeValueOrDie(),
- kernel, padding, padding_config));
+ return CreatePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -252,11 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
computation->AddInstruction(HloInstruction::CreateConstant(
MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
- computation->AddInstruction(HloInstruction::CreatePad(
- ShapeInference::InferPadShape(input->shape(), padding->shape(),
- input_padding_config)
- .ConsumeValueOrDie(),
- input, padding, input_padding_config));
+ CreatePadHlo(input, padding, input_padding_config).ValueOrDie();
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.