diff options
author | Tim Shen <timshen@google.com> | 2018-09-25 16:05:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 16:13:15 -0700 |
commit | 22776289fbe30ca7f4b1a80d7e23f5bddca391c2 (patch) | |
tree | 0787a51a81e7603d2204ac42f9fee54261f6d996 /tensorflow/compiler/xla/service/gpu | |
parent | e62cd643839d264659285a273bcf34df1057136e (diff) |
Add a new pass after convolution rewriter and pad insertion, to pattern
match convolution forward + relu.
PiperOrigin-RevId: 214521083
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
7 files changed, 635 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 2775527e0c..51968d13d4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -655,6 +655,7 @@ cc_library( deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", + ":cudnn_fused_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -967,3 +968,19 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "cudnn_fused_convolution_rewriter", + srcs = ["cudnn_fused_convolution_rewriter.cc"], + hdrs = ["cudnn_fused_convolution_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc new file mode 100644 index 0000000000..3761c19cfc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc @@ -0,0 +1,278 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. + + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf<HloInstruction>( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. + // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf<HloInstruction>( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf<HloInstruction>( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast<HloCustomCallInstruction>(conv_instr); + auto bias_broadcast = + CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. + if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull<HloConstantInstruction>(alpha_conv_instr), + CastOrNull<HloConstantInstruction>(alpha_side_input_instr)}; +} + +StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr<double> { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast<HloConstantInstruction>(instr)->literal().Convert(F64)); + return alpha.GetFirstElement<double>(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector<HloInstruction*> args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config<CudnnConvBackendConfig>()); + config.set_activation_mode( + static_cast<int64>(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector<ConvWithRelu> matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast<HloCustomCallInstruction>(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>> + replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h new file mode 100644 index 0000000000..bd12aadded --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +class CudnnFusedConvolutionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr<bool> Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 01a18f4f8e..0b3b429710 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass<CudnnConvolutionRewriter>(); + pipeline.AddPass<CudnnFusedConvolutionRewriter>(); pipeline.AddPass<PadInsertion>(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<PadForTensorCores>(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 7e77dc9ac6..b42a19e3a2 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -30,7 +30,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -385,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) { } for (HloInstruction* instruction : convs) { const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget) { changed |= CanonicalizeForwardConvolution(instruction); } else if (target == kCudnnConvBackwardFilterCallTarget) { changed |= CanonicalizeBackwardFilterConvolution(instruction); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 5da6f232d5..a725533567 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -209,3 +209,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cudnn_fused_convolution_rewriter_test", + srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc new file mode 100644 index 0000000000..5632cac186 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc @@ -0,0 +1,283 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CudnnFusedConvolutionRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_EQ(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convForward")) + << optimized_hlo_string; + EXPECT_NE(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convBiasActivationForward")) + << optimized_hlo_string; + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); + EXPECT_NE(absl::string_view::npos, + optimized_hlo.find("__cudnn$convForward")) + << optimized_hlo; + EXPECT_EQ(absl::string_view::npos, + optimized_hlo.find("__cudnn$convBiasActivationForward")) + << optimized_hlo; + } + } +}; + +TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, + TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +} // namespace +} // namespace gpu +} // namespace xla |