author     Justin Lebar <jlebar@google.com>                 2018-08-30 20:38:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 20:46:31 -0700
commit     423633fc4fb2b9c75f6013c3ded8eca8fe06843d (patch)
tree       0195723e75fa0ef2937b7bc930ae0f5abe9571e9
parent     06ea8fb214b1b859b211ded0bbe31726214ee3f2 (diff)
[XLA] Merge kPad into kConvolution's window where possible.
This allows us to use e.g. cudnn's padding, instead of materializing a kPad instruction.

PiperOrigin-RevId: 211028379
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc       136
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc  258
2 files changed, 392 insertions(+), 2 deletions(-)
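As an illustration (not from the patch itself): on the input side, the kPad's edge padding simply adds to the convolution window's padding, and interior padding of k becomes base (lhs) dilation k+1; the merge is refused whenever one side has interior padding/dilation and the other has any padding at all. Below is a minimal self-contained C++ sketch of that rule for a single spatial dimension, where PadDim, WindowDim, and MergePadIntoWindow are invented stand-ins for the real XLA protos and pass logic.

#include <cstdint>
#include <iostream>

// Simplified stand-ins for XLA's PaddingConfig dimension and window dimension.
struct PadDim { int64_t edge_low, edge_high, interior; };
struct WindowDim { int64_t padding_low, padding_high, base_dilation; };

// Mirrors the patch's input-side rule: edge padding accumulates, interior
// padding k becomes base dilation k+1, and if either side has interior
// padding/dilation, the other side must be completely unpadded.
bool MergePadIntoWindow(const PadDim& p, WindowDim* w) {
  if (p.interior != 0 && (w->padding_low != 0 || w->padding_high != 0 ||
                          w->base_dilation != 1)) {
    return false;
  }
  if (w->base_dilation != 1 &&
      (p.edge_low != 0 || p.edge_high != 0 || p.interior != 0)) {
    return false;
  }
  w->padding_low += p.edge_low;  // Edge padding composes by addition.
  w->padding_high += p.edge_high;
  if (p.interior != 0) w->base_dilation = 1 + p.interior;
  return true;
}

int main() {
  // Pad config 1_2_3 on an unpadded conv dim -> pad=1_2 lhs_dilate=4, as in
  // the "0_0x0_0x1_2_3x4_5_6" testcase added below.
  PadDim p{1, 2, 3};
  WindowDim w{0, 0, 1};
  std::cout << MergePadIntoWindow(p, &w) << " pad=" << w.padding_low << "_"
            << w.padding_high << " lhs_dilate=" << w.base_dilation
            << "\n";  // Prints: 1 pad=1_2 lhs_dilate=4
}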
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index a7a0044308..212ae97b59 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2213,7 +2213,141 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
.CloneToUnique())),
{}));
}
+
const auto& window = convolution->window();
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ // Try to merge padding/dilation of the input with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
+ if (lhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(lhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = lhs->padding_config();
+
+ // Can't pad batch or feature dims.
+ for (int64 dim :
+ {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = window;
+ for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+ // Edge padding composes with itself in the straightforward way, but
+ // composing interior padding is nontrivial, and we cowardly refuse to
+ // think about it. If we see interior padding in either the kPad or conv,
+ // bail if there's any sort of padding in the other.
+ if (p.interior_padding() != 0 &&
+ (w.padding_low() != 0 || w.padding_high() != 0 ||
+ w.base_dilation() != 1)) {
+ return false;
+ }
+ if (w.base_dilation() != 1 &&
+ (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0)) {
+ return false;
+ }
+
+ w.set_padding_low(w.padding_low() + p.edge_padding_low());
+ w.set_padding_high(w.padding_high() + p.edge_padding_high());
+ if (p.interior_padding() != 0) {
+ CHECK_EQ(w.base_dilation(), 1);
+ w.set_base_dilation(1 + p.interior_padding());
+ }
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs->mutable_operand(0), rhs});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_input_pad) {
+ return Status::OK();
+ }
+
+ // Try to merge dilation of the filter with the convolution's window.
+ TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
+ if (rhs->opcode() != HloOpcode::kPad) {
+ return false;
+ }
+
+ // Convolution's padding is always zero, so bail if the kPad is adding
+ // something other than zero.
+ if (!IsAll(rhs->operand(1), 0)) {
+ return false;
+ }
+
+ const auto& padding = rhs->padding_config();
+
+ // Can't pad or dilate feature dims.
+ for (int64 dim : {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()}) {
+ const auto& p = padding.dimensions(dim);
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+ p.interior_padding() != 0) {
+ return false;
+ }
+ }
+
+ // Compute the window which is the result of merging the kPad and the
+ // convolution's existing window.
+ Window new_window = convolution->window();
+ for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+ auto& w = *new_window.mutable_dimensions(dim);
+ const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
+
+ // We can only do this transformation if p adds dilation to the filter --
+ // edge padding on the filter is not supported in conv.
+ if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+ return false;
+ }
+
+ // Nothing to do if the kPad for this dim is entirely a nop.
+ if (p.interior_padding() == 0) {
+ continue;
+ }
+
+ // We cowardly refuse to think about how dilation composes with itself;
+ // bail if both the kPad and conv have dilation on this dimension.
+ if (w.window_dilation() > 1) {
+ return false;
+ }
+ CHECK_EQ(w.window_dilation(), 1);
+ w.set_window_dilation(1 + p.interior_padding());
+ w.set_size(rhs->operand(0)->shape().dimensions(
+ dnums.kernel_spatial_dimensions(dim)));
+ }
+
+ auto new_conv = convolution->CloneWithNewOperands(
+ convolution->shape(), {lhs, rhs->mutable_operand(0)});
+ new_conv->set_window(new_window);
+ TF_RETURN_IF_ERROR(
+ ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+ return true;
+ }());
+
+ if (folded_filter_pad) {
+ return Status::OK();
+ }
+
if (!enable_conv_simplification_) {
return Status::OK();
}
@@ -2230,8 +2364,6 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
return Status::OK();
}
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
const Shape& input_shape = lhs->shape();
const Shape& filter_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();
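Similarly, a sketch (again not from the patch itself) of the filter-side rule implemented above: convolutions cannot express edge padding of the filter, so only interior padding can be folded, becoming window (rhs) dilation k+1, after which the window size must snap back to the unpadded filter's extent. PadDim, WindowDim, and MergeFilterPadIntoWindow are invented stand-ins.

#include <cstdint>
#include <iostream>

struct PadDim { int64_t edge_low, edge_high, interior; };
struct WindowDim { int64_t size, window_dilation; };

bool MergeFilterPadIntoWindow(const PadDim& p, int64_t unpadded_filter_size,
                              WindowDim* w) {
  // Convolutions have no notion of edge padding on the filter, so bail.
  if (p.edge_low != 0 || p.edge_high != 0) return false;
  if (p.interior == 0) return true;  // The kPad is a nop on this dim.
  // Refuse to compose the kPad's dilation with existing window dilation.
  if (w->window_dilation > 1) return false;
  w->window_dilation = 1 + p.interior;  // k interior zeros = dilation k+1.
  w->size = unpadded_filter_size;       // Window spans the unpadded filter.
  return true;
}

int main() {
  // Interior padding 5 on an undilated 3-wide filter dim: the padded filter
  // is 3 + (3-1)*5 = 13 wide, and folding yields size=3 rhs_dilate=6, as in
  // the "0_0x0_0x0_0_5x0_0" testcase added below.
  PadDim p{0, 0, 5};
  WindowDim w{13, 1};
  std::cout << MergeFilterPadIntoWindow(p, /*unpadded_filter_size=*/3, &w)
            << " size=" << w.size << " rhs_dilate=" << w.window_dilation
            << "\n";  // Prints: 1 size=3 rhs_dilate=6
}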
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 182c581ad8..b4ff048db0 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2123,6 +2123,264 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
}
+// Used for TEST_Ps that test merging (or not) of a kPad instruction into a
+// convolution's Window.
+struct ConvPaddingTestcase {
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window)
+ : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
+ /*pad_value=*/0) {}
+
+ ConvPaddingTestcase(absl::string_view padding,
+ absl::string_view orig_conv_window,
+ absl::string_view expected_conv_window, float pad_value)
+ : padding(padding),
+ orig_conv_window(orig_conv_window),
+ expected_conv_window(expected_conv_window),
+ pad_value(pad_value) {}
+
+ string ToString() const {
+ return absl::StrFormat(
+ "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
+ "pad_value=%f",
+ padding, orig_conv_window, expected_conv_window, pad_value);
+ }
+
+ string padding;
+ string orig_conv_window;
+ string expected_conv_window;
+ float pad_value;
+};
+
+// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
+// computation that does
+//
+// conv(pad(param0, padding=padding), param1), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvInputPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvInputPaddingTestCases, ConvInputPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+ // Merge this edge padding into the conv.
+ {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
+ // Merge this edge padding with the conv's edge padding.
+ {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
+ // Merge this interior-padded kPad with the unpadded conv. The 3x6
+ // interior padding gets transformed to 4x7 conv lhs dilation.
+ {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
+ // kPad has dilation on one dim, conv has it on the other; merge them.
+ {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
+ // kPad has dilation and edge padding on one dim, conv has them on the
+ // other; merge them.
+ {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
+ "pad=0_1x3_0 lhs_dilate=2x10"},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
+
+ // We refuse to transform the following because on some dimension, one
+ // of the kPad and conv has dilation and the other has some sort of
+ // padding.
+ {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
+ {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
+
+ // We can't merge feature or batch padding into the conv.
+ {"1_0x0_0x0_0x0_0", "", ""},
+ {"0_0x1_0x0_0x0_0", "", ""},
+ }));
+
+TEST_P(ConvInputPaddingTest, DoTest) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01
+ "input"));
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ input, pad_value, padding_config));
+
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1,
+ ShapeUtil::MakeShape(
+ F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window =
+ ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
+ .ValueOrDie();
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
+ window, dnums)
+ .ValueOrDie(),
+ lhs_pad, filter, window, dnums));
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrCat("size=3x3 ", testcase.expected_conv_window));
+ }
+}
+
+// ConvFilterPaddingTest (and its one associated TEST_P) checks that a
+// computation that does
+//
+// conv(param0, pad(param1, padding=padding)), window=orig_conv_window
+//
+// gets transformed by AlgebraicSimplifier to
+//
+// conv(param0, param1), window=expected_conv_window
+//
+// or, if expected_conv_window is the empty string, checks that
+// AlgebraicSimplifier does *not* transform the original convolution.
+class ConvFilterPaddingTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<ConvPaddingTestcase> {};
+
+INSTANTIATE_TEST_CASE_P(
+ ConvFilterPaddingTestCases, ConvFilterPaddingTest,
+ ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
+ // Can only merge interior padding on the filter's spatial dimensions;
+ // all other paddings (edge padding and interior padding on the channel
+ // dims) should be rejected out of hand.
+ {"1_0_0x0_0_0x0_0x0_0", "", ""},
+ {"0_1_0x0_0_0x0_0x0_0", "", ""},
+ {"0_0_1x0_0_0x0_0x0_0", "", ""},
+ {"0_0_0x1_0_0x0_0x0_0", "", ""},
+ {"0_0_0x0_1_0x0_0x0_0", "", ""},
+ {"0_0_0x0_0_1x0_0x0_0", "", ""},
+ {"0_0_0x0_0_0x1_0x0_0", "", ""},
+ {"0_0_0x0_0_0x0_1x0_0", "", ""},
+ {"0_0_0x0_0_0x0_0x1_0", "", ""},
+ {"0_0_0x0_0_0x0_0x0_1", "", ""},
+
+ // Interior padding on spatial dims can be merged into the conv, so long
+ // as the conv and pad don't have interior padding on the same dim.
+ {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
+ {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
+ {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
+
+ // Can't merge if for a given dim there's interior padding on both the
+ // pad and conv.
+ {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
+ {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
+
+ // Don't transform if the pad value is nonzero.
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
+ }));
+
+TEST_P(ConvFilterPaddingTest, DoIt) {
+ ConvPaddingTestcase testcase = GetParam();
+
+ // It would be better to put the testcase's ToString into the test name, but
+ // gUnit has constraints on what can go into test names, and any reasonable
+ // implementation of ToString() seems to violate them.
+ SCOPED_TRACE(testcase.ToString());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(testcase.pad_value)));
+ auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01
+ "input"));
+ PaddingConfig padding_config =
+ ParsePaddingConfig(testcase.padding).ValueOrDie();
+ auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
+ padding_config)
+ .ValueOrDie(),
+ filter, pad_value, padding_config));
+
+ auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeShape(
+ F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01
+ "input"));
+
+ ConvolutionDimensionNumbers dnums =
+ ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
+ Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
+ rhs_pad->shape().dimensions(2),
+ rhs_pad->shape().dimensions(3),
+ testcase.orig_conv_window))
+ .ValueOrDie();
+ auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ window, dnums)
+ .ValueOrDie(),
+ input, rhs_pad, window, dnums));
+
+ // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
+ // after the transformation.
+ PrecisionConfigProto precision_config;
+ precision_config.add_operand_precision(PrecisionConfigProto::HIGH);
+ precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST);
+ orig_conv->set_precision_config(precision_config);
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ if (testcase.expected_conv_window.empty()) {
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
+ } else {
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+ auto* conv = module->entry_computation()->root_instruction();
+ SCOPED_TRACE(module->ToString());
+ ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter()));
+ EXPECT_EQ(window_util::ToString(conv->window()),
+ absl::StrFormat("size=%dx%d %s",
+ conv->operand(1)->shape().dimensions(2),
+ conv->operand(1)->shape().dimensions(3),
+ testcase.expected_conv_window));
+ EXPECT_THAT(
+ conv->precision_config().operand_precision(),
+ ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST));
+ }
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;