author     Justin Lebar <jlebar@google.com>                2018-07-30 14:39:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-30 14:42:44 -0700
commit     815f0b329ab362f652701a8b4f56bc854f2424ca (patch)
tree       80133fd3cae793a55bf25075f9c5ab5763d39af0
parent     c57870c1156688e71f9ae001970c2e8c2146c7a1 (diff)
[XLA:GPU] Pad the channel dims of f16 convs to multiples of 8 on Volta.
This lets us actually use Volta's tensor cores for such convs.

PiperOrigin-RevId: 206649341
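A rough worked example of the new heuristic, with numbers taken from the tests in this change (RoundUpToNearest, kDesiredNumFeaturesFactor, and kMaxBytesTouchedIncrease are from pad_for_tensor_cores.cc below):

    padded_channels = RoundUpToNearest(41, kDesiredNumFeaturesFactor)  // 41 -> 48
    bytes_ratio     = 48 / 41.0                                        // ~1.17
    // 1.17 <= kMaxBytesTouchedIncrease (1.2), so the conv is padded and its
    // result is sliced back down to 41 channels. A pad like 4 -> 8 would be a
    // 2.0x byte increase, and that conv would be left untouched.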
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                          |  34
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc              |   8
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc        | 233
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h         |  45
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc   | 164
5 files changed, 484 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a73a341fdb..e0aae3866b 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -545,6 +545,38 @@ cc_library(
)
cc_library(
+ name = "pad_for_tensor_cores",
+ srcs = ["pad_for_tensor_cores.cc"],
+ hdrs = ["pad_for_tensor_cores.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ ],
+)
+
+tf_cc_test(
+ name = "pad_for_tensor_cores_test",
+ srcs = ["pad_for_tensor_cores_test.cc"],
+ deps = [
+ ":ir_emission_utils",
+ ":pad_for_tensor_cores",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
+ ],
+)
+
+cc_library(
name = "gpu_transfer_manager",
srcs = ["gpu_transfer_manager.cc"],
hdrs = ["gpu_transfer_manager.h"],
@@ -588,9 +620,11 @@ cc_library(
":ir_emission_utils",
":ir_emitter",
":multi_output_fusion",
+ ":pad_for_tensor_cores",
":pad_insertion",
":partition_assignment",
":stream_assignment",
+ ":stream_executor_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 6d8996dac1..7a683ede54 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -52,9 +52,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -199,6 +201,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
+ if (IsVoltaOrLater(*stream_exec)) {
+ pipeline.AddPass<PadForTensorCores>();
+ // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element
+ // pairs that TupleSimplifier fixes.
+ pipeline.AddPass<TupleSimplifier>();
+ }
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
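Outside of this pipeline, the pass can also be run directly. A minimal sketch (namespaces elided; RunPadForTensorCores is a hypothetical helper, and it assumes the module's cudnn convs still report 0 scratch bytes, i.e. CudnnConvolutionAlgorithmPicker has not run yet):

    #include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
    #include "tensorflow/compiler/xla/service/tuple_simplifier.h"

    Status RunPadForTensorCores(HloModule* module) {
      TF_ASSIGN_OR_RETURN(bool changed, gpu::PadForTensorCores().Run(module));
      if (changed) {
        // Clean up the tuple/get-tuple-element pairs the pass leaves behind.
        TF_RETURN_IF_ERROR(TupleSimplifier().Run(module).status());
      }
      return Status::OK();
    }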
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
new file mode 100644
index 0000000000..79f7d31816
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -0,0 +1,233 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+
+namespace xla {
+namespace gpu {
+
+using tensorflow::gtl::ArraySlice;
+
+// We want the input/output feature counts of an f16 conv to be multiples of
+// 8, because otherwise cudnn can't use tensor cores for the conv.
+static constexpr int64 kDesiredNumFeaturesFactor = 8;
+
+// We won't pad a conv if doing so increases the total number of bytes in the
+// lhs, rhs, or result by more than this factor.
+//
+// TODO(jlebar): This number was tuned experimentally. It represents a
+// compromise on our current benchmarks; it speeds some up significantly, and
+// doesn't slow any down. But we can observe by changing this value that
+// there's additional room for speedups. Achieving those speedups without also
+// slowing other things down will likely require a more sophisticated heuristic,
+// possibly some form of auto-tuning.
+static constexpr double kMaxBytesTouchedIncrease = 1.2;
+
+// Pads the given dimensions in the given shape up to a multiple of
+// kDesiredNumFeaturesFactor.
+static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+ for (int64 dim : dims) {
+ int64 dim_to_pad_size = s.dimensions(dim);
+ int64 new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ s.set_dimensions(dim, new_dim_to_pad_size);
+ }
+ return s;
+}
+
+// Creates and returns an HLO that zero-pads one or more dimensions in the given
+// instruction so that its shape is equal to the given shape.
+//
+// Padding is added to the end of each relevant dimension.
+//
+// If the instruction already has the given shape, simply returns it without an
+// intervening pad.
+static HloInstruction* PadInstruction(HloInstruction* instr,
+ const Shape& new_shape) {
+ HloComputation* comp = instr->parent();
+
+ const Shape& shape = instr->shape();
+ auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+
+ PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
+
+ bool added_padding = false;
+ for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) {
+ if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
+ continue;
+ }
+ CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
+ pad_config.mutable_dimensions(dim)->set_edge_padding_high(
+ new_shape.dimensions(dim) - shape.dimensions(dim));
+ added_padding = true;
+ }
+
+ if (!added_padding) {
+ return instr;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreatePad(new_shape, instr, zero, pad_config));
+}
+
+// Pads the input/output feature dimensions of the given cudnn convolution
+// custom-call to be multiples of kDesiredNumFeaturesFactor.
+static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+ CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
+ << "conv must use 0 scratch bytes, i.e. this pass must be run "
+ "before CudnnConvolutionAlgorithmPicker.";
+
+ const auto& target = conv->custom_call_target();
+ const auto& dnums = conv->convolution_dimension_numbers();
+ auto* lhs = conv->mutable_operand(0);
+ auto* rhs = conv->mutable_operand(1);
+ const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+ Shape new_lhs_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBackwardFilterCallTarget) {
+ // LHS is "input".
+ return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
+ // LHS is "output".
+ return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+ }();
+
+ Shape new_rhs_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBackwardInputCallTarget) {
+ // RHS is "filter".
+ return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+ // RHS is "output".
+ return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+ }();
+
+ if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
+ ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
+ VLOG(3) << "No need to pad features of " << conv->ToString();
+ return false;
+ }
+
+ Shape new_result_shape = [&] {
+ if (target == kCudnnConvForwardCallTarget) {
+ // Result is "output".
+ return PadShape(result_shape, {dnums.output_feature_dimension()});
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ // Result is "input".
+ return PadShape(result_shape, {dnums.input_feature_dimension()});
+ }
+ CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+ // Result is "filter".
+ return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
+ dnums.kernel_output_feature_dimension()});
+ }();
+
+ // Check that padding wouldn't increase the total bytes read/written by this
+ // operation too much.
+ auto check_size_increase = [&](const Shape& old_shape,
+ const Shape& new_shape) {
+ int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+ int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+ if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) {
+ return true;
+ }
+ VLOG(3) << "Not padding convolution; doing so would change input / result "
+ "shape from "
+ << ShapeUtil::HumanString(old_shape) << " to "
+ << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+ << new_bytes / static_cast<double>(old_bytes) << "x > "
+ << kMaxBytesTouchedIncrease << "x: " << conv->ToString();
+ return false;
+ };
+ if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
+ !check_size_increase(rhs->shape(), new_rhs_shape) ||
+ !check_size_increase(result_shape, new_result_shape)) {
+ return false;
+ }
+
+ // OK, let's do the transformation!
+
+ auto* new_lhs = PadInstruction(lhs, new_lhs_shape);
+ auto* new_rhs = PadInstruction(rhs, new_rhs_shape);
+ CHECK(new_lhs != lhs || new_rhs != rhs)
+ << "We should have had to pad either LHS or RHS.";
+
+ auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
+ return conv->parent()->AddInstruction(std::move(new_instr));
+ };
+
+ Shape new_conv_shape = ShapeUtil::MakeTupleShape(
+ {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
+ auto* new_conv =
+ add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs}));
+
+ // Slice the new conv result if necessary, keeping in mind that new_conv has
+ // tuple shape (new_result_shape, u8[0]).
+ if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
+ std::vector<int64> start_indices(result_shape.dimensions_size(), 0);
+ std::vector<int64> end_indices(result_shape.dimensions().begin(),
+ result_shape.dimensions().end());
+ std::vector<int64> strides(result_shape.dimensions_size(), 1);
+
+ auto* new_conv_result = add(
+ HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
+ auto* empty_temp_buffer =
+ add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8>({})));
+ auto* sliced_result = add(HloInstruction::CreateSlice(
+ result_shape, new_conv_result, start_indices, end_indices, strides));
+ new_conv =
+ add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
+ }
+
+ VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
+ << new_conv->ToString();
+ TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv));
+ return true;
+}
+
+static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
+ std::vector<HloInstruction*> convs;
+ for (HloInstruction* instr : comp->instructions()) {
+ if (IsCustomCallToDnnConvolution(*instr) &&
+ instr->operand(0)->shape().element_type() == F16) {
+ convs.push_back(instr);
+ }
+ }
+ return convs;
+}
+
+StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* comp : module->MakeNonfusionComputations()) {
+ for (HloInstruction* conv : GetRelevantConvs(comp)) {
+ TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
+ changed |= result;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
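For the forward conv whose output channels are padded (the third test below), the rewrite produces roughly the following HLO; this is a hand-written sketch rather than actual compiler output, with the custom-call and pad/slice attributes abbreviated:

    %filter.padded = f16[2,2,40,48] pad(%filter, %zero), padding=0_0x0_0x0_0x0_7
    %conv.padded   = (f16[10,20,30,48], u8[0]) custom-call(%input, %filter.padded), ...
    %gte           = f16[10,20,30,48] get-tuple-element(%conv.padded), index=0
    %slice         = f16[10,20,30,41] slice(%gte), slice={[0:10], [0:20], [0:30], [0:41]}
    ROOT %tuple    = (f16[10,20,30,41], u8[0]) tuple(%slice, %empty_scratch)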
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
new file mode 100644
index 0000000000..192359f026
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Ensures that f16 cudnn convolutions have input/output channel dimensions that
+// are multiples of 8, inserting pads/slices as necessary.
+//
+// This is useful primarily for Volta and newer GPUs, where tensor cores can
+// only be used if the channel dims are multiples of 8. It's probably the
+// opposite of useful on other GPUs, so you should check what GPU you're
+// targeting before running this pass.
+//
+// TODO(jlebar): Also pad dots.
+class PadForTensorCores : public HloPassInterface {
+ public:
+ tensorflow::StringPiece name() const override {
+ return "pad for tensor cores";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
new file mode 100644
index 0000000000..99e7580b82
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
+
+using PadForTensorCoresTest = HloVerifiedTestBase;
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+
+ SCOPED_TRACE(module().ToString());
+ EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget,
+ op::Pad(op::Parameter(0), _),
+ op::Pad(op::Parameter(1), _)));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+ ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+ ShapeUtil::MakeShape(F16, {2, 2, 48, 40})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,41] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget,
+ op::Pad(op::Parameter(0), _),
+ op::Pad(op::Parameter(1), _)));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+ ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+ ShapeUtil::MakeShape(F16, {2, 2, 40, 48})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,40,41] parameter(1)
+ ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convForward"
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvForwardCallTarget, op::Parameter(0),
+ op::Pad(op::Parameter(1), _)))),
+ _));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ output = f16[10,20,30,40] parameter(0)
+ filter = f16[2,2,41,40] parameter(1)
+ result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardInput"
+ ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardInputCallTarget, op::Parameter(0),
+ op::Pad(op::Parameter(1), _)))),
+ _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,41] parameter(0)
+ output = f16[10,20,30,40] parameter(1)
+ result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardFilterCallTarget,
+ op::Pad(op::Parameter(0), _), op::Parameter(1)))),
+ _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) {
+ ParseAndVerifyModule(R"(
+ HloModule TestModule
+
+ ENTRY TestComputation {
+ input = f16[10,20,30,40] parameter(0)
+ output = f16[10,20,30,41] parameter(1)
+ result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
+ window={size=2x2}, dim_labels=b01f_01io->b01f,
+ custom_call_target="__cudnn$convBackwardFilter"
+ ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
+ })");
+ EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+ auto* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+ op::Slice(op::GetTupleElement(op::CustomCall(
+ kCudnnConvBackwardFilterCallTarget,
+ op::Parameter(0), op::Pad(op::Parameter(1), _)))),
+ _)));
+}
+
+} // anonymous namespace
+} // namespace gpu
+} // namespace xla