diff options
author | Tim Shen <timshen@google.com> | 2018-09-24 17:44:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 17:48:32 -0700 |
commit | 391cdd80952e9cc546d82a8bf2fe7dd04f46cb2f (patch) | |
tree | a253b5bd01d2088d07e1cae505028a600d834384 | |
parent | 9ab01c6732dae1143e22713375a9cc7758216787 (diff) |
Add cuDNN fused convolution forward support.
The tests are in the next patch.
PiperOrigin-RevId: 214362688
11 files changed, 169 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7231fd844e..2775527e0c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -433,6 +433,7 @@ cc_library( srcs = ["cudnn_convolution_rewriter.cc"], hdrs = ["cudnn_convolution_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -597,14 +598,11 @@ cc_library( hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 640c6392b8..78e14d860e 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -24,4 +24,18 @@ message CudnnConvBackendConfig { // true, cudnn may choose not to use tensor cores, e.g. because the GPU or // selected algorithm doesn't support it. bool tensor_ops_enabled = 2; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // CudnnConvParams for their meanings. + + // The requested activation (e.g. relu) after the convolution. It is with type + // stream_executor::dnn::ActivationMode. + int64 activation_mode = 3; + + // The scaling factor multiplied with the side input. If no side input buffer + // is provided, this field must be 0. + double side_input_scale = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 391456576f..7125673887 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -89,6 +89,7 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind, succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); break; case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); break; } @@ -363,8 +364,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction( - instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0), - instr->mutable_operand(1)})); + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 2834d47412..ef29237301 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -476,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) { return std::make_tuple(true, new_window, dnums, rhs); } +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; +} + // Tries to rewrite a single convolution into a call to cudnn. StatusOr<bool> RunOnInstruction(HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); @@ -515,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { return false; } + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 32d67084b3..89dd1bb272 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -40,6 +40,25 @@ using se::dnn::FilterLayout; using se::dnn::ProfileResult; struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. + struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + CudnnConvKind kind; const Shape* input_shape; const Shape* filter_shape; @@ -51,6 +70,9 @@ struct CudnnConvParams { const ConvolutionDimensionNumbers* dnums; int64 feature_group_count; se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional<FusionParams> fusion; }; // A StreamExecutor ScratchAllocator that wraps a single XLA allocation, @@ -202,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, switch (kind) { case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveWithAlgorithm( input_descriptor, input_buf, filter_descriptor, filter_buf, convolution_descriptor, output_descriptor, &output_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardDataWithAlgorithm( filter_descriptor, filter_buf, output_descriptor, output_buf, convolution_descriptor, input_descriptor, &input_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardFilterWithAlgorithm( input_descriptor, input_buf, output_descriptor, output_buf, convolution_descriptor, filter_descriptor, &filter_buf, scratch_allocator, algorithm, profile_result); break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory<T> side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. + if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } } if (!stream->ok()) { @@ -250,6 +322,7 @@ StatusOr<CudnnConvParams> GetCudnnConvParams( params.feature_group_count = conv->feature_group_count(); params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); if (target == kCudnnConvForwardCallTarget) { params.kind = CudnnConvKind::kForward; @@ -275,6 +348,29 @@ StatusOr<CudnnConvParams> GetCudnnConvParams( params.input_buf = operand_buffers[0]; params.filter_buf = result_buffer; params.output_buf = operand_buffers[1]; + } else if (target == kCudnnConvBiasActivationForwardCallTarget) { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast<se::dnn::ActivationMode>( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } } else { return InternalError("Unexpected custom call target: %s", target); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 06314e413e..74352f26aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -104,6 +104,7 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); switch (kind) { case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: input_shape = &lhs_shape; filter_shape = &rhs_shape; output_shape = &result_shape; @@ -153,6 +154,20 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); TF_RETURN_IF_ERROR( constraints->SetBufferLayout(result_shape.layout(), *call_result_buf)); + // instr->operand(2), if exists, is the bias buffer. There is no need to + // assign layout to it, as it has only one dimension. + + // instr->opernad(3), if exists, is the side input buffer. + if (instr->operand_count() == 4) { + if (kind != CudnnConvKind::kForwardActivation) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward: %s", + instr->ToString()); + } + // The side input layout must match the output layout. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3)); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 76757faf60..ec3d8f9405 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget = "__cudnn$convBackwardInput"; const char* const kCudnnConvBackwardFilterCallTarget = "__cudnn$convBackwardFilter"; +const char* const kCudnnConvBiasActivationForwardCallTarget = + "__cudnn$convBiasActivationForward"; bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { @@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || target == kCudnnConvBackwardInputCallTarget || - target == kCudnnConvBackwardFilterCallTarget; + target == kCudnnConvBackwardFilterCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget; } bool ImplementedAsLibraryCall(const HloInstruction& hlo) { @@ -247,6 +250,9 @@ StatusOr<CudnnConvKind> GetCudnnConvKind( if (target == kCudnnConvBackwardFilterCallTarget) { return CudnnConvKind::kBackwardFilter; } + if (target == kCudnnConvBiasActivationForwardCallTarget) { + return CudnnConvKind::kForwardActivation; + } return InternalError("Unexpected call target: %s", target); } @@ -258,6 +264,8 @@ string CudnnConvKindToString(CudnnConvKind kind) { return "backward_filter"; case CudnnConvKind::kBackwardInput: return "backward_input"; + case CudnnConvKind::kForwardActivation: + return "forward with activation"; } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 744346abf3..a64a616ab1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -44,9 +44,11 @@ namespace gpu { // "connectivity" (i.e. which elements of the input affect which elements of // the output) are concerned. enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output }; StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr); @@ -119,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index b0061fa655..2d270f630b 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -209,7 +210,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) { std::vector<HloInstruction*> convs; for (HloInstruction* instr : comp->instructions()) { if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16) { + instr->operand(0)->shape().element_type() == F16 && + // TODO(timshen): Disable for fused conv for now. Implement it if it's + // needed. + Cast<HloCustomCallInstruction>(instr)->custom_call_target() != + kCudnnConvBiasActivationForwardCallTarget) { convs.push_back(instr); } } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index eead408f10..7e77dc9ac6 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -162,8 +162,12 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. VLOG(1) << "Canonicalizing forward conv"; + std::vector<HloInstruction*> operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; auto new_conv = conv->parent()->AddInstruction( - conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel})); + conv->CloneWithNewOperands(conv->shape(), operands)); new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 9abfa1db6a..621b155240 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -873,7 +873,7 @@ class NormalizeDescriptor { // Describes a kind of non-linearity (threshold-like mathematical function). enum class ActivationMode { - kNone, + kNone = 0, kSigmoid, // Rectified linear activation: f(x) = x < 0 ? 0 : x kRelu, @@ -885,6 +885,8 @@ enum class ActivationMode { kTanh, // Like ReluX, but passes all values in the range [-X,X]. kBandPass, + + kNumActivationModes, // Always in the end. }; // Returns a string representation of the given activation mode. |