From 5be479930d3dcfa3edb863703b1d73b89d45f03c Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 9 Oct 2018 17:19:24 -0700 Subject: [XLA:GPU] Use CudnnConvKind in more places. No functional change. PiperOrigin-RevId: 216451881 --- .../compiler/xla/service/gpu/pad_insertion.cc | 31 +++++++++++----------- 1 file changed, 16 insertions(+), 15 deletions(-) (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc') diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b42a19e3a2..ae7abca7c6 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -378,25 +379,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - std::vector convs; + std::vector convs; for (auto* instr : computation->instructions()) { if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); + convs.push_back(Cast(instr)); } } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBiasActivationForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); } return changed; } -- cgit v1.2.3