author     Justin Lebar <jlebar@google.com>                 2018-07-30 14:39:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-30 14:42:44 -0700
commit     815f0b329ab362f652701a8b4f56bc854f2424ca
tree       80133fd3cae793a55bf25075f9c5ab5763d39af0
parent     c57870c1156688e71f9ae001970c2e8c2146c7a1
[XLA:GPU] Pad the channel dims of f16 convs to multiples of 8 on Volta.
This lets us actually use Volta's tensor cores for such convs.
PiperOrigin-RevId: 206649341
5 files changed, 484 insertions, 0 deletions
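
The transformation this commit performs is easy to state: round each channel dimension of an f16 conv up to the next multiple of 8, zero-pad the operands to that shape, run the conv on the padded operands, and slice the result back down to the original shape. A minimal standalone sketch of the rounding arithmetic, using a hypothetical helper name RoundUpToMultiple (the pass itself uses XLA's RoundUpToNearest plus HLO pad/slice instructions, as shown in the diff below):

    #include <cstdint>
    #include <iostream>

    // Round n up to the next multiple of factor, e.g. 41 -> 48 for factor 8.
    int64_t RoundUpToMultiple(int64_t n, int64_t factor) {
      return (n + factor - 1) / factor * factor;
    }

    int main() {
      int64_t channels = 41;                            // original channel dim
      int64_t padded = RoundUpToMultiple(channels, 8);  // 48
      int64_t pad_high = padded - channels;             // 7 zeros appended
      std::cout << channels << " -> " << padded
                << " (edge_padding_high=" << pad_high << ")\n";
    }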
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a73a341fdb..e0aae3866b 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -545,6 +545,38 @@ cc_library(
 )
 
 cc_library(
+    name = "pad_for_tensor_cores",
+    srcs = ["pad_for_tensor_cores.cc"],
+    hdrs = ["pad_for_tensor_cores.h"],
+    deps = [
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:window_util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_creation_utils",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/compiler/xla/service:shape_inference",
+    ],
+)
+
+tf_cc_test(
+    name = "pad_for_tensor_cores_test",
+    srcs = ["pad_for_tensor_cores_test.cc"],
+    deps = [
+        ":ir_emission_utils",
+        ":pad_for_tensor_cores",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # build_cleaner: keep
+    ],
+)
+
+cc_library(
     name = "gpu_transfer_manager",
     srcs = ["gpu_transfer_manager.cc"],
     hdrs = ["gpu_transfer_manager.h"],
@@ -588,9 +620,11 @@ cc_library(
         ":ir_emission_utils",
         ":ir_emitter",
         ":multi_output_fusion",
+        ":pad_for_tensor_cores",
        ":pad_insertion",
         ":partition_assignment",
         ":stream_assignment",
+        ":stream_executor_util",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 6d8996dac1..7a683ede54 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -52,9 +52,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
 #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -199,6 +201,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     pipeline.AddInvariantChecker<HloVerifier>();
     pipeline.AddPass<CudnnConvolutionRewriter>();
     pipeline.AddPass<PadInsertion>();
+    if (IsVoltaOrLater(*stream_exec)) {
+      pipeline.AddPass<PadForTensorCores>();
+      // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element
+      // pairs that TupleSimplifier fixes.
+      pipeline.AddPass<TupleSimplifier>();
+    }
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
new file mode 100644
index 0000000000..79f7d31816
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -0,0 +1,233 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+
+namespace xla {
+namespace gpu {
+
+using tensorflow::gtl::ArraySlice;
+
+// We want the input/output feature counts of an f16 conv to be multiples of 8,
+// because otherwise cudnn can't use tensor cores on the conv.
+static constexpr int64 kDesiredNumFeaturesFactor = 8;
+
+// We won't pad a conv if doing so increases the total number of bytes in the
+// lhs, rhs, or result by more than this amount.
+//
+// TODO(jlebar): This number was tuned experimentally.  It represents a
+// compromise on our current benchmarks; it speeds some benchmarks up
+// significantly and doesn't slow any down.  But we can observe by changing
+// this value that there's additional room for speedups.  Achieving those
+// speedups without also slowing other things down will likely require a more
+// sophisticated heuristic, possibly some form of auto-tuning.
+static constexpr double kMaxBytesTouchedIncrease = 1.2;
+
+// Pads the given dimensions in the given shape up to a multiple of
+// kDesiredNumFeaturesFactor.
+static Shape PadShape(Shape s, ArraySlice<int64> dims) {
+  for (int64 dim : dims) {
+    int64 dim_to_pad_size = s.dimensions(dim);
+    int64 new_dim_to_pad_size =
+        RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+    s.set_dimensions(dim, new_dim_to_pad_size);
+  }
+  return s;
+}
+
+// Creates and returns an HLO that zero-pads one or more dimensions in the
+// given instruction so that its shape is equal to the given shape.
+//
+// Padding is added to the end of each relevant dimension.
+//
+// If the instruction already has the given shape, simply returns it without an
+// intervening pad.
+static HloInstruction* PadInstruction(HloInstruction* instr,
+                                      const Shape& new_shape) {
+  HloComputation* comp = instr->parent();
+
+  const Shape& shape = instr->shape();
+  auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+
+  PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
+
+  bool added_padding = false;
+  for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) {
+    if (shape.dimensions(dim) == new_shape.dimensions(dim)) {
+      continue;
+    }
+    CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim));
+    pad_config.mutable_dimensions(dim)->set_edge_padding_high(
+        new_shape.dimensions(dim) - shape.dimensions(dim));
+    added_padding = true;
+  }
+
+  if (!added_padding) {
+    return instr;
+  }
+  return comp->AddInstruction(
+      HloInstruction::CreatePad(new_shape, instr, zero, pad_config));
+}
+
+// Pads the input/output feature dimensions of the given cudnn convolution
+// custom-call to be multiples of kDesiredNumFeaturesFactor.
+static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+  CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
+      << "conv must use 0 scratch bytes, i.e. this pass must be run "
+         "before CudnnConvolutionAlgorithmPicker.";
+
+  const auto& target = conv->custom_call_target();
+  const auto& dnums = conv->convolution_dimension_numbers();
+  auto* lhs = conv->mutable_operand(0);
+  auto* rhs = conv->mutable_operand(1);
+  const Shape& result_shape = conv->shape().tuple_shapes(0);
+
+  Shape new_lhs_shape = [&] {
+    if (target == kCudnnConvForwardCallTarget ||
+        target == kCudnnConvBackwardFilterCallTarget) {
+      // LHS is "input".
+      return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+    }
+    CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
+    // LHS is "output".
+    return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+  }();
+
+  Shape new_rhs_shape = [&] {
+    if (target == kCudnnConvForwardCallTarget ||
+        target == kCudnnConvBackwardInputCallTarget) {
+      // RHS is "filter".
+      return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
+                                     dnums.kernel_output_feature_dimension()});
+    }
+    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+    // RHS is "output".
+    return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+  }();
+
+  if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
+      ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) {
+    VLOG(3) << "No need to pad features of " << conv->ToString();
+    return false;
+  }
+
+  Shape new_result_shape = [&] {
+    if (target == kCudnnConvForwardCallTarget) {
+      // Result is "output".
+      return PadShape(result_shape, {dnums.output_feature_dimension()});
+    }
+    if (target == kCudnnConvBackwardInputCallTarget) {
+      // Result is "input".
+      return PadShape(result_shape, {dnums.input_feature_dimension()});
+    }
+    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
+    // Result is "filter".
+    return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
+                                   dnums.kernel_output_feature_dimension()});
+  }();
+
+  // Check that padding wouldn't increase the total bytes read/written by this
+  // operation too much.
+  auto check_size_increase = [&](const Shape& old_shape,
+                                 const Shape& new_shape) {
+    int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape);
+    int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape);
+    if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) {
+      return true;
+    }
+    VLOG(3) << "Not padding convolution; doing so would change input / result "
+               "shape from "
+            << ShapeUtil::HumanString(old_shape) << " to "
+            << ShapeUtil::HumanString(new_shape) << ", a size increase of "
+            << new_bytes / static_cast<double>(old_bytes) << "x > "
+            << kMaxBytesTouchedIncrease << "x: " << conv->ToString();
+    return false;
+  };
+  if (!check_size_increase(lhs->shape(), new_lhs_shape) ||
+      !check_size_increase(rhs->shape(), new_rhs_shape) ||
+      !check_size_increase(result_shape, new_result_shape)) {
+    return false;
+  }
+
+  // OK, let's do the transformation!
+
+  auto* new_lhs = PadInstruction(lhs, new_lhs_shape);
+  auto* new_rhs = PadInstruction(rhs, new_rhs_shape);
+  CHECK(new_lhs != lhs || new_rhs != rhs)
+      << "We should have had to pad either LHS or RHS.";
+
+  auto add = [&](std::unique_ptr<HloInstruction> new_instr) {
+    return conv->parent()->AddInstruction(std::move(new_instr));
+  };
+
+  Shape new_conv_shape = ShapeUtil::MakeTupleShape(
+      {new_result_shape, ShapeUtil::MakeShape(U8, {0})});
+  auto* new_conv =
+      add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs}));
+
+  // Slice the new conv result if necessary, keeping in mind that new_conv
+  // has tuple shape (new_result_shape, u8[0]).
+  if (!ShapeUtil::Equal(result_shape, new_result_shape)) {
+    std::vector<int64> start_indices(result_shape.dimensions_size(), 0);
+    std::vector<int64> end_indices(result_shape.dimensions().begin(),
+                                   result_shape.dimensions().end());
+    std::vector<int64> strides(result_shape.dimensions_size(), 1);
+
+    auto* new_conv_result = add(
+        HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0));
+    auto* empty_temp_buffer =
+        add(HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint8>({})));
+    auto* sliced_result = add(HloInstruction::CreateSlice(
+        result_shape, new_conv_result, start_indices, end_indices, strides));
+    new_conv =
+        add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer}));
+  }
+
+  VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with "
+          << new_conv->ToString();
+  TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv));
+  return true;
+}
+
+static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
+  std::vector<HloInstruction*> convs;
+  for (HloInstruction* instr : comp->instructions()) {
+    if (IsCustomCallToDnnConvolution(*instr) &&
+        instr->operand(0)->shape().element_type() == F16) {
+      convs.push_back(instr);
+    }
+  }
+  return convs;
+}
+
+StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* comp : module->MakeNonfusionComputations()) {
+    for (HloInstruction* conv : GetRelevantConvs(comp)) {
+      TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
+      changed |= result;
+    }
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
new file mode 100644
index 0000000000..192359f026
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Ensures that f16 cudnn convolutions have input/output channel dimensions
+// that are multiples of 8, inserting pads/slices as necessary.
+//
+// This is useful primarily for Volta and newer GPUs, where tensor cores can
+// only be used if the channel dims are multiples of 8.  It's probably the
+// opposite of useful on other GPUs, so you should check what GPU you're
+// targeting before running this pass.
+//
+// TODO(jlebar): Also pad dots.
+class PadForTensorCores : public HloPassInterface {
+ public:
+  tensorflow::StringPiece name() const override {
+    return "pad for tensor cores";
+  }
+
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
new file mode 100644
index 0000000000..99e7580b82
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+using ::testing::_;
+
+using PadForTensorCoresTest = HloVerifiedTestBase;
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      input = f16[10,20,30,41] parameter(0)
+      filter = f16[2,2,41,40] parameter(1)
+      ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter),
+                    window={size=2x2}, dim_labels=b01f_01io->b01f,
+                    custom_call_target="__cudnn$convForward"
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+
+  SCOPED_TRACE(module().ToString());
+  EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget,
+                                   op::Pad(op::Parameter(0), _),
+                                   op::Pad(op::Parameter(1), _)));
+  EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+                               ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+  EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+                               ShapeUtil::MakeShape(F16, {2, 2, 48, 40})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      output = f16[10,20,30,41] parameter(0)
+      filter = f16[2,2,40,41] parameter(1)
+      ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter),
+                    window={size=2x2}, dim_labels=b01f_01io->b01f,
+                    custom_call_target="__cudnn$convBackwardInput"
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget,
+                                   op::Pad(op::Parameter(0), _),
+                                   op::Pad(op::Parameter(1), _)));
+  EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(),
+                               ShapeUtil::MakeShape(F16, {10, 20, 30, 48})));
+  EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(),
+                               ShapeUtil::MakeShape(F16, {2, 2, 40, 48})));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      input = f16[10,20,30,40] parameter(0)
+      filter = f16[2,2,40,41] parameter(1)
+      ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter),
+                    window={size=2x2}, dim_labels=b01f_01io->b01f,
+                    custom_call_target="__cudnn$convForward"
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall(
+                                  kCudnnConvForwardCallTarget, op::Parameter(0),
+                                  op::Pad(op::Parameter(1), _)))),
+                              _));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      output = f16[10,20,30,40] parameter(0)
+      filter = f16[2,2,41,40] parameter(1)
+      result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
+               window={size=2x2}, dim_labels=b01f_01io->b01f,
+               custom_call_target="__cudnn$convBackwardInput"
+      ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+                        op::Slice(op::GetTupleElement(op::CustomCall(
+                            kCudnnConvBackwardInputCallTarget, op::Parameter(0),
+                            op::Pad(op::Parameter(1), _)))),
+                        _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      input = f16[10,20,30,41] parameter(0)
+      output = f16[10,20,30,40] parameter(1)
+      result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
+               window={size=2x2}, dim_labels=b01f_01io->b01f,
+               custom_call_target="__cudnn$convBackwardFilter"
+      ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+                        op::Slice(op::GetTupleElement(op::CustomCall(
+                            kCudnnConvBackwardFilterCallTarget,
+                            op::Pad(op::Parameter(0), _), op::Parameter(1)))),
+                        _)));
+}
+
+TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) {
+  ParseAndVerifyModule(R"(
+    HloModule TestModule
+
+    ENTRY TestComputation {
+      input = f16[10,20,30,40] parameter(0)
+      output = f16[10,20,30,41] parameter(1)
+      result = (f16[2,2,40,41], u8[0]) custom-call(input, output),
+               window={size=2x2}, dim_labels=b01f_01io->b01f,
+               custom_call_target="__cudnn$convBackwardFilter"
+      ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
+    })");
+  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  auto* root = module().entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
+                        op::Slice(op::GetTupleElement(op::CustomCall(
+                            kCudnnConvBackwardFilterCallTarget,
+                            op::Parameter(0), op::Pad(op::Parameter(1), _)))),
+                        _)));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
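
A quick illustration of the kMaxBytesTouchedIncrease heuristic from pad_for_tensor_cores.cc above (the numbers here are my own, not from the commit): since only channel dimensions change, the bytes-touched ratio for an operand reduces to padded channels over original channels. The 41-channel shapes used in these tests pass the 1.2x check, while a 3-channel RGB-style input layer would not:

    #include <iostream>

    int main() {
      // 41 -> 48 channels: 48/41 ~= 1.17x <= 1.2x, so the pass pads.
      std::cout << 48.0 / 41.0 << "\n";  // prints ~1.1707
      // 3 -> 8 channels: 8/3 ~= 2.67x > 1.2x, so the pass bails out.
      std::cout << 8.0 / 3.0 << "\n";    // prints ~2.6667
    }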