From 5be479930d3dcfa3edb863703b1d73b89d45f03c Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Tue, 9 Oct 2018 17:19:24 -0700
Subject: [XLA:GPU] Use CudnnConvKind in more places.

No functional change.

PiperOrigin-RevId: 216451881
---
 tensorflow/compiler/xla/service/gpu/BUILD        |  1 +
 .../xla/service/gpu/cudnn_convolution_runner.cc  | 99 +++++++++++-----------
 .../xla/service/gpu/pad_for_tensor_cores.cc      | 84 ++++++++++--------
 .../compiler/xla/service/gpu/pad_insertion.cc    | 31 +++----
 4 files changed, 116 insertions(+), 99 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 0144d59097..62da43d68a 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -591,6 +591,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_creation_utils",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 89dd1bb272..a809c22b33 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -312,11 +312,12 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
   TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                       conv->backend_config<CudnnConvBackendConfig>());
-  const auto& target = conv->custom_call_target();
+  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv));
   const auto& lhs_shape = conv->operand(0)->shape();
   const auto& rhs_shape = conv->operand(1)->shape();
   const auto& conv_result_shape = conv->shape().tuple_shapes(0);
 
+  params.kind = kind;
   params.window = &conv->window();
   params.dnums = &conv->convolution_dimension_numbers();
   params.feature_group_count = conv->feature_group_count();
@@ -324,55 +325,55 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
       backend_config.algorithm(), backend_config.tensor_ops_enabled()));
   params.conv_result_scale = backend_config.conv_result_scale();
 
-  if (target == kCudnnConvForwardCallTarget) {
-    params.kind = CudnnConvKind::kForward;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &conv_result_shape;
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = result_buffer;
-  } else if (target == kCudnnConvBackwardInputCallTarget) {
-    params.kind = CudnnConvKind::kBackwardInput;
-    params.input_shape = &conv_result_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &lhs_shape;
-    params.input_buf = result_buffer;
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = operand_buffers[0];
-  } else if (target == kCudnnConvBackwardFilterCallTarget) {
-    params.kind = CudnnConvKind::kBackwardFilter;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &conv_result_shape;
-    params.output_shape = &rhs_shape;
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = result_buffer;
-    params.output_buf = operand_buffers[1];
-  } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
-    params.kind = CudnnConvKind::kForwardActivation;
-    params.input_shape = &lhs_shape;
-    params.filter_shape = &rhs_shape;
-    params.output_shape = &conv_result_shape;
-    params.fusion.emplace();
-    auto& fusion = *params.fusion;
-    if (backend_config.activation_mode() <
-        static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
-      fusion.mode = static_cast<se::dnn::ActivationMode>(
-          backend_config.activation_mode());
-    } else {
-      return InternalError("Bad activation mode: %s",
-                           backend_config.ShortDebugString());
-    }
-    fusion.side_input_scale = backend_config.side_input_scale();
-    params.input_buf = operand_buffers[0];
-    params.filter_buf = operand_buffers[1];
-    params.output_buf = result_buffer;
-    params.fusion->bias_buf = operand_buffers[2];
-    if (operand_buffers.size() >= 4) {
-      params.fusion->side_input_buf = operand_buffers[3];
+  switch (kind) {
+    case CudnnConvKind::kForward:
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &conv_result_shape;
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = result_buffer;
+      break;
+    case CudnnConvKind::kBackwardInput:
+      params.input_shape = &conv_result_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &lhs_shape;
+      params.input_buf = result_buffer;
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = operand_buffers[0];
+      break;
+    case CudnnConvKind::kBackwardFilter:
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &conv_result_shape;
+      params.output_shape = &rhs_shape;
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = result_buffer;
+      params.output_buf = operand_buffers[1];
+      break;
+    case CudnnConvKind::kForwardActivation: {
+      params.kind = CudnnConvKind::kForwardActivation;
+      params.input_shape = &lhs_shape;
+      params.filter_shape = &rhs_shape;
+      params.output_shape = &conv_result_shape;
+      params.fusion.emplace();
+      auto& fusion = *params.fusion;
+      if (backend_config.activation_mode() <
+          static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+        fusion.mode = static_cast<se::dnn::ActivationMode>(
+            backend_config.activation_mode());
+      } else {
+        return InternalError("Bad activation mode: %s",
+                             backend_config.ShortDebugString());
+      }
+      fusion.side_input_scale = backend_config.side_input_scale();
+      params.input_buf = operand_buffers[0];
+      params.filter_buf = operand_buffers[1];
+      params.output_buf = result_buffer;
+      params.fusion->bias_buf = operand_buffers[2];
+      if (operand_buffers.size() >= 4) {
+        params.fusion->side_input_buf = operand_buffers[3];
+      }
     }
-  } else {
-    return InternalError("Unexpected custom call target: %s", target);
   }
   return params;
 }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index e3869b5c36..8f1f5a7bf5 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -105,38 +105,45 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
 
 // Pads the input/output feature dimensions of the given cudnn convolution
 // custom-call to be multiples of kDesiredNumFeaturesFactor.
-static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+static StatusOr<bool> PadFeaturesDims(HloCustomCallInstruction* conv) {
   CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
       << "conv must use 0 scratch bytes, i.e. this pass must be run "
          "before CudnnConvolutionAlgorithmPicker.";
 
-  const auto& target = conv->custom_call_target();
+  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
   const auto& dnums = conv->convolution_dimension_numbers();
   auto* lhs = conv->mutable_operand(0);
   auto* rhs = conv->mutable_operand(1);
   const Shape& result_shape = conv->shape().tuple_shapes(0);
 
   Shape new_lhs_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget ||
-        target == kCudnnConvBackwardFilterCallTarget) {
-      // LHS is "input".
-      return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+      case CudnnConvKind::kBackwardFilter:
+        // LHS is "input".
+        return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+      case CudnnConvKind::kBackwardInput:
+        // LHS is "output".
+        return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
-    // LHS is "output".
-    return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
   }();
 
   Shape new_rhs_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget ||
-        target == kCudnnConvBackwardInputCallTarget) {
-      // RHS is "filter".
-      return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
-                                     dnums.kernel_output_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+      case CudnnConvKind::kBackwardInput:
+        // RHS is "filter".
+        return PadShape(rhs->shape(),
+                        {dnums.kernel_input_feature_dimension(),
+                         dnums.kernel_output_feature_dimension()});
+      case CudnnConvKind::kBackwardFilter:
+        // RHS is "output".
+        return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
-    // RHS is "output".
-    return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
   }();
 
   if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
@@ -146,18 +153,21 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
   }
 
   Shape new_result_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget) {
-      // Result is "output".
-      return PadShape(result_shape, {dnums.output_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+        // Result is "output".
+        return PadShape(result_shape, {dnums.output_feature_dimension()});
+      case CudnnConvKind::kBackwardInput:
+        // Result is "input".
+        return PadShape(result_shape, {dnums.input_feature_dimension()});
+      case CudnnConvKind::kBackwardFilter:
+        // Result is "filter".
+        return PadShape(result_shape,
+                        {dnums.kernel_input_feature_dimension(),
+                         dnums.kernel_output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    if (target == kCudnnConvBackwardInputCallTarget) {
-      // Result is "input".
-      return PadShape(result_shape, {dnums.input_feature_dimension()});
-    }
-    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
-    // Result is "filter".
-    return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
-                                   dnums.kernel_output_feature_dimension()});
   }();
 
   // Check that padding wouldn't increase the total bytes read/written by this
@@ -223,16 +233,20 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
   return true;
 }
 
-static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
-  std::vector<HloInstruction*> convs;
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+    HloComputation* comp) {
+  std::vector<HloCustomCallInstruction*> convs;
   for (HloInstruction* instr : comp->instructions()) {
-    if (IsCustomCallToDnnConvolution(*instr) &&
-        instr->operand(0)->shape().element_type() == F16 &&
+    if (!IsCustomCallToDnnConvolution(*instr)) {
+      continue;
+    }
+    auto* custom_call = Cast<HloCustomCallInstruction>(instr);
+    if (custom_call->operand(0)->shape().element_type() == F16 &&
         // TODO(timshen): Disable for fused conv for now. Implement it if it's
         // needed.
-        Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+        custom_call->custom_call_target() !=
             kCudnnConvBiasActivationForwardCallTarget) {
-      convs.push_back(instr);
+      convs.push_back(custom_call);
     }
   }
   return convs;
@@ -241,7 +255,7 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
 StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* comp : module->MakeNonfusionComputations()) {
-    for (HloInstruction* conv : GetRelevantConvs(comp)) {
+    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
       TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
       changed |= result;
     }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b42a19e3a2..ae7abca7c6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -378,25 +379,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - std::vector convs; + std::vector convs; for (auto* instr : computation->instructions()) { if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); + convs.push_back(Cast(instr)); } } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBiasActivationForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); } return changed; } -- cgit v1.2.3