diff options
author | 2018-08-30 20:38:46 -0700 | |
---|---|---|
committer | 2018-08-30 20:46:31 -0700 | |
commit | 423633fc4fb2b9c75f6013c3ded8eca8fe06843d (patch) | |
tree | 0195723e75fa0ef2937b7bc930ae0f5abe9571e9 | |
parent | 06ea8fb214b1b859b211ded0bbe31726214ee3f2 (diff) |
[XLA] Merge kPad into kConvolution's window where possible.
This allows us to use e.g. cudnn's padding, instead of materializing a
kPad instruction.
PiperOrigin-RevId: 211028379
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 136 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 258 |
2 files changed, 392 insertions(+), 2 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index a7a0044308..212ae97b59 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2213,7 +2213,141 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( .CloneToUnique())), {})); } + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. 
+ if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { + return false; + } + + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); + } + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; + }()); + + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> { + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; + } + + const auto& padding = rhs->padding_config(); + + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. 
+ Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); + + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; + } + + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } + + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; + } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; + }()); + + if (folded_filter_pad) { + return Status::OK(); + } + if (!enable_conv_simplification_) { return Status::OK(); } @@ -2230,8 +2364,6 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( return Status::OK(); } - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); const Shape& input_shape = lhs->shape(); const Shape& filter_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 182c581ad8..b4ff048db0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2123,6 +2123,264 @@ 
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); } +// Used for TEST_Ps that test merging (or not) of a kPad instruction into a +// convolution's Window. +struct ConvPaddingTestcase { + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window) + : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window, + /*pad_value=*/0) {} + + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window, float pad_value) + : padding(padding), + orig_conv_window(orig_conv_window), + expected_conv_window(expected_conv_window), + pad_value(pad_value) {} + + string ToString() const { + return absl::StrFormat( + "padding=%s, orig_conv_window=%s, expected_conv_window=%s, " + "pad_value=%f", + padding, orig_conv_window, expected_conv_window, pad_value); + } + + string padding; + string orig_conv_window; + string expected_conv_window; + float pad_value; +}; + +// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a +// computation that does +// +// conv(pad(param0, padding=padding), param1), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvInputPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface<ConvPaddingTestcase> {}; + +INSTANTIATE_TEST_CASE_P( + ConvInputPaddingTestCases, ConvInputPaddingTest, + ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{ + // Merge this edge padding into the conv. + {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"}, + // Merge this edge padding with the conv's edge padding. 
+ {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"}, + // Merge this interior-padded kPad with the unpadded conv. The 3x6 + // interior padding gets transformed to 4x7 conv lhs dilation. + {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"}, + // kPad has dilation on one dim, conv has it on the other; merge them. + {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"}, + // kPad has dilation and edge padding on one dim, conv has them on the + // other; merge them. + {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10", + "pad=0_1x3_0 lhs_dilate=2x10"}, + + // Don't transform if the pad value is nonzero. + {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1}, + + // We refuse to transform the following because on some dimension, one + // of the kPad and conv has dilation and the other has some sort of + // padding. + {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + + // We can't merge feature or batch padding into the conv. + {"1_0x0_0x0_0x0_0", "", ""}, + {"0_0x1_0x0_0x0_0", "", ""}, + })); + +TEST_P(ConvInputPaddingTest, DoTest) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. 
+ SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01 + "input")); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(input->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + input, pad_value, padding_config)); + + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, + ShapeUtil::MakeShape( + F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = + ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window)) + .ValueOrDie(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), + window, dnums) + .ValueOrDie(), + lhs_pad, filter, window, dnums)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrCat("size=3x3 ", testcase.expected_conv_window)); + } +} + +// ConvFilterPaddingTest (and its one associated TEST_P) checks that a +// computation that does +// +// 
conv(param0, pad(param1, padding=padding)), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvFilterPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface<ConvPaddingTestcase> {}; + +INSTANTIATE_TEST_CASE_P( + ConvFilterPaddingTestCases, ConvFilterPaddingTest, + ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{ + // Can only merge interior padding on the filter's spatial dimensions; + // all + // other paddings (edge padding and interior padding on the channel + // dims) + // should be rejected out of hand. + {"1_0_0x0_0_0x0_0x0_0", "", ""}, + {"0_1_0x0_0_0x0_0x0_0", "", ""}, + {"0_0_1x0_0_0x0_0x0_0", "", ""}, + {"0_0_0x1_0_0x0_0x0_0", "", ""}, + {"0_0_0x0_1_0x0_0x0_0", "", ""}, + {"0_0_0x0_0_1x0_0x0_0", "", ""}, + {"0_0_0x0_0_0x1_0x0_0", "", ""}, + {"0_0_0x0_0_0x0_1x0_0", "", ""}, + {"0_0_0x0_0_0x0_0x1_0", "", ""}, + {"0_0_0x0_0_0x0_0x0_1", "", ""}, + + // Interior padding on channel dims can be merged into the conv, so long + // as the conv and pad don't have interior padding on the same dim. + {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"}, + {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"}, + {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"}, + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"}, + + // Can't merge if for a given dim there's interior padding on both the + // pad and conv. + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""}, + + // Don't transform if the pad value is nonzero. 
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1}, + })); + +TEST_P(ConvFilterPaddingTest, DoIt) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. + SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01 + "input")); + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(filter->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + filter, pad_value, padding_config)); + + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeShape( + F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = ParseWindow(absl::StrFormat("size=%dx%d %s", + rhs_pad->shape().dimensions(2), + rhs_pad->shape().dimensions(3), + testcase.orig_conv_window)) + .ValueOrDie(); + auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + window, dnums) + .ValueOrDie(), + input, rhs_pad, window, dnums)); + + // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place + // after the transformation. 
+ PrecisionConfigProto precision_config; + precision_config.add_operand_precision(PrecisionConfigProto::HIGH); + precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST); + orig_conv->set_precision_config(precision_config); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrFormat("size=%dx%d %s", + conv->operand(1)->shape().dimensions(2), + conv->operand(1)->shape().dimensions(3), + testcase.expected_conv_window)); + EXPECT_THAT( + conv->precision_config().operand_precision(), + ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST)); + } +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; |