aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc31
1 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b42a19e3a2..ae7abca7c6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -378,25 +379,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
bool changed = false;
- std::vector<HloInstruction*> convs;
+ std::vector<HloCustomCallInstruction*> convs;
for (auto* instr : computation->instructions()) {
if (IsCustomCallToDnnConvolution(*instr)) {
- convs.push_back(instr);
+ convs.push_back(Cast<HloCustomCallInstruction>(instr));
}
}
- for (HloInstruction* instruction : convs) {
- const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget ||
- target == kCudnnConvBiasActivationForwardCallTarget) {
- changed |= CanonicalizeForwardConvolution(instruction);
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- changed |= CanonicalizeBackwardFilterConvolution(instruction);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- changed |= CanonicalizeBackwardInputConvolution(instruction);
- } else {
- LOG(FATAL) << "Unknown custom call target for cudnn conv: "
- << instruction->ToString();
- }
+ for (HloCustomCallInstruction* instruction : convs) {
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
+ changed |= [&] {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ return CanonicalizeForwardConvolution(instruction);
+ case CudnnConvKind::kBackwardInput:
+ return CanonicalizeBackwardInputConvolution(instruction);
+ case CudnnConvKind::kBackwardFilter:
+ return CanonicalizeBackwardFilterConvolution(instruction);
+ }
+ }();
}
return changed;
}