diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 84 |
1 files changed, 49 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index e3869b5c36..8f1f5a7bf5 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -105,38 +105,45 @@ static HloInstruction* PadInstruction(HloInstruction* instr, // Pads the input/output feature dimensions of the given cudnn convolution // custom-call to be multiples of kDesiredNumFeaturesFactor. -static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) { +static StatusOr<bool> PadFeaturesDims(HloCustomCallInstruction* conv) { CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) << "conv must use 0 scratch bytes, i.e. this pass must be run " "before CudnnConvolutionAlgorithmPicker."; - const auto& target = conv->custom_call_target(); + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); auto* rhs = conv->mutable_operand(1); const Shape& result_shape = conv->shape().tuple_shapes(0); Shape new_lhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardFilterCallTarget) { - // LHS is "input". - return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kBackwardFilter: + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + case CudnnConvKind::kBackwardInput: + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); - // LHS is "output". - return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); }(); Shape new_rhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardInputCallTarget) { - // RHS is "filter". - return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kBackwardInput: + // RHS is "filter". + return PadShape(rhs->shape(), + {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + case CudnnConvKind::kBackwardFilter: + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // RHS is "output". - return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); }(); if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && @@ -146,18 +153,21 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) { } Shape new_result_shape = [&] { - if (target == kCudnnConvForwardCallTarget) { - // Result is "output". - return PadShape(result_shape, {dnums.output_feature_dimension()}); + switch (kind) { + case CudnnConvKind::kForward: + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + case CudnnConvKind::kBackwardInput: + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + case CudnnConvKind::kBackwardFilter: + // Result is "filter". + return PadShape(result_shape, + {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + case CudnnConvKind::kForwardActivation: + LOG(FATAL) << "Not yet implemented."; } - if (target == kCudnnConvBackwardInputCallTarget) { - // Result is "input". - return PadShape(result_shape, {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // Result is "filter". - return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); }(); // Check that padding wouldn't increase the total bytes read/written by this @@ -223,16 +233,20 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) { return true; } -static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) { - std::vector<HloInstruction*> convs; +static std::vector<HloCustomCallInstruction*> GetRelevantConvs( + HloComputation* comp) { + std::vector<HloCustomCallInstruction*> convs; for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16 && + if (!IsCustomCallToDnnConvolution(*instr)) { + continue; + } + auto* custom_call = Cast<HloCustomCallInstruction>(instr); + if (custom_call->operand(0)->shape().element_type() == F16 && // TODO(timshen): Disable for fused conv for now. Implement it if it's // needed. - Cast<HloCustomCallInstruction>(instr)->custom_call_target() != + custom_call->custom_call_target() != kCudnnConvBiasActivationForwardCallTarget) { - convs.push_back(instr); + convs.push_back(custom_call); } } return convs; @@ -241,7 +255,7 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) { StatusOr<bool> PadForTensorCores::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* conv : GetRelevantConvs(comp)) { + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); changed |= result; } |