author     Tim Shen <timshen@google.com>                    2018-09-25 16:05:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-25 16:13:15 -0700
commit     22776289fbe30ca7f4b1a80d7e23f5bddca391c2 (patch)
tree       0787a51a81e7603d2204ac42f9fee54261f6d996 /tensorflow/compiler/xla/service/gpu
parent     e62cd643839d264659285a273bcf34df1057136e (diff)
Add a new pass, after the convolution rewriter and before pad insertion, to
pattern-match forward convolution + relu.

PiperOrigin-RevId: 214521083
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                                            17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc             278
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h               37
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc                                  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc                                   6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/BUILD                                       14
-rw-r--r--  tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc   283
7 files changed, 635 insertions, 2 deletions
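
For orientation, here is a minimal before/after sketch of the rewrite this pass performs, written in HLO text in the style of the tests added below. It is not part of the commit; it assumes CudnnConvolutionRewriter has already lowered the convolution to the __cudnn$convForward custom call, and the concrete shapes, the u8[0] scratch-buffer element, and the omitted convolution attributes (window, dim_labels, backend config) are illustrative only.

  // Before the pass: max(0, conv(x, w) + broadcast(bias)), with the conv
  // already lowered to a cuDNN custom call returning a (result, scratch) tuple.
  zero             = f32[] constant(0)
  zeros            = f32[1,3,3,64] broadcast(zero), dimensions={}
  input            = f32[1,3,3,64] parameter(0)
  filter           = f32[3,3,64,64] parameter(1)
  bias             = f32[64] parameter(2)
  conv             = (f32[1,3,3,64], u8[0]) custom-call(input, filter), custom_call_target="__cudnn$convForward"
  gte              = f32[1,3,3,64] get-tuple-element(conv), index=0
  broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
  sum              = f32[1,3,3,64] add(gte, broadcasted_bias)
  relu             = f32[1,3,3,64] maximum(zeros, sum)

  // After the pass: a single fused custom call taking (input, filter, bias[, side_input]).
  // The scales and the relu are recorded in its backend config (conv_result_scale,
  // side_input_scale, activation_mode); the broadcast/add/maximum above go away.
  fused  = (f32[1,3,3,64], u8[0]) custom-call(input, filter, bias), custom_call_target="__cudnn$convBiasActivationForward"
  result = f32[1,3,3,64] get-tuple-element(fused), index=0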
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 2775527e0c..51968d13d4 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -655,6 +655,7 @@ cc_library(
deps = [
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
+ ":cudnn_fused_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@@ -967,3 +968,19 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "cudnn_fused_convolution_rewriter",
+ srcs = ["cudnn_fused_convolution_rewriter.cc"],
+ hdrs = ["cudnn_fused_convolution_rewriter.h"],
+ deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
new file mode 100644
index 0000000000..3761c19cfc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -0,0 +1,278 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Describes a matched pattern:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// where side_input has the shape of the output buffer, and bias is a 1D array
+// whose length is the number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
+
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ // The pattern we want to match:
+ // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+ //
+ // along with its variants that commute or reassociate the adds, multiplies,
+ // and max, or that omit alpha1, side_input, alpha2, or bias.
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(match::ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
+ }
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
+
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
+
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
+ // If bias is already matched, match an arbitrary additional input as the
+ // side input. Note that this may force a cheap operation (e.g. a broadcast)
+ // to be materialized into a buffer as large as the output buffer.
+ //
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+ // Try to match any of the following forms of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
+ //
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
+ } else {
+ return absl::nullopt;
+ }
+ }
+ }
+ }
+
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
+ }
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
+}
+
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+ PrimitiveType element_type = conv->operand(0)->shape().element_type();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+
+ int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(element_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+} // namespace
+
+StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
new file mode 100644
index 0000000000..bd12aadded
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedConvolutionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 01a18f4f8e..0b3b429710 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
+ pipeline.AddPass<CudnnFusedConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 7e77dc9ac6..b42a19e3a2 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -30,7 +30,8 @@ namespace gpu {
namespace {

bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -385,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
}
for (HloInstruction* instruction : convs) {
const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget) {
changed |= CanonicalizeForwardConvolution(instruction);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
changed |= CanonicalizeBackwardFilterConvolution(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 5da6f232d5..a725533567 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -209,3 +209,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cudnn_fused_convolution_rewriter_test",
+ srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
new file mode 100644
index 0000000000..5632cac186
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+ protected:
+ string GetOptimizedHlo(absl::string_view hlo_string) {
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest())
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString();
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convForward"))
+ << optimized_hlo_string;
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo_string;
+ EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ string optimized_hlo = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convForward"))
+ << optimized_hlo;
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo;
+ }
+ }
+};
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest,
+ TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla