Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc | 84
1 file changed, 49 insertions(+), 35 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index e3869b5c36..8f1f5a7bf5 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -105,38 +105,45 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
 
 // Pads the input/output feature dimensions of the given cudnn convolution
 // custom-call to be multiples of kDesiredNumFeaturesFactor.
-static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
+static StatusOr<bool> PadFeaturesDims(HloCustomCallInstruction* conv) {
   CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
       << "conv must use 0 scratch bytes, i.e. this pass must be run "
          "before CudnnConvolutionAlgorithmPicker.";
 
-  const auto& target = conv->custom_call_target();
+  TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
   const auto& dnums = conv->convolution_dimension_numbers();
   auto* lhs = conv->mutable_operand(0);
   auto* rhs = conv->mutable_operand(1);
   const Shape& result_shape = conv->shape().tuple_shapes(0);
 
   Shape new_lhs_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget ||
-        target == kCudnnConvBackwardFilterCallTarget) {
-      // LHS is "input".
-      return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+      case CudnnConvKind::kBackwardFilter:
+        // LHS is "input".
+        return PadShape(lhs->shape(), {dnums.input_feature_dimension()});
+      case CudnnConvKind::kBackwardInput:
+        // LHS is "output".
+        return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    CHECK_EQ(target, kCudnnConvBackwardInputCallTarget);
-    // LHS is "output".
-    return PadShape(lhs->shape(), {dnums.output_feature_dimension()});
   }();
 
   Shape new_rhs_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget ||
-        target == kCudnnConvBackwardInputCallTarget) {
-      // RHS is "filter".
-      return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(),
-                                     dnums.kernel_output_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+      case CudnnConvKind::kBackwardInput:
+        // RHS is "filter".
+        return PadShape(rhs->shape(),
+                        {dnums.kernel_input_feature_dimension(),
+                         dnums.kernel_output_feature_dimension()});
+      case CudnnConvKind::kBackwardFilter:
+        // RHS is "output".
+        return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
-    // RHS is "output".
-    return PadShape(rhs->shape(), {dnums.output_feature_dimension()});
   }();
 
   if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) &&
@@ -146,18 +153,21 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
   }
 
   Shape new_result_shape = [&] {
-    if (target == kCudnnConvForwardCallTarget) {
-      // Result is "output".
-      return PadShape(result_shape, {dnums.output_feature_dimension()});
+    switch (kind) {
+      case CudnnConvKind::kForward:
+        // Result is "output".
+        return PadShape(result_shape, {dnums.output_feature_dimension()});
+      case CudnnConvKind::kBackwardInput:
+        // Result is "input".
+        return PadShape(result_shape, {dnums.input_feature_dimension()});
+      case CudnnConvKind::kBackwardFilter:
+        // Result is "filter".
+        return PadShape(result_shape,
+                        {dnums.kernel_input_feature_dimension(),
+                         dnums.kernel_output_feature_dimension()});
+      case CudnnConvKind::kForwardActivation:
+        LOG(FATAL) << "Not yet implemented.";
     }
-    if (target == kCudnnConvBackwardInputCallTarget) {
-      // Result is "input".
-      return PadShape(result_shape, {dnums.input_feature_dimension()});
-    }
-    CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget);
-    // Result is "filter".
-    return PadShape(result_shape, {dnums.kernel_input_feature_dimension(),
-                                   dnums.kernel_output_feature_dimension()});
   }();
 
   // Check that padding wouldn't increase the total bytes read/written by this
@@ -223,16 +233,20 @@ static StatusOr<bool> PadFeaturesDims(HloInstruction* conv) {
   return true;
 }
 
-static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
-  std::vector<HloInstruction*> convs;
+static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
+    HloComputation* comp) {
+  std::vector<HloCustomCallInstruction*> convs;
   for (HloInstruction* instr : comp->instructions()) {
-    if (IsCustomCallToDnnConvolution(*instr) &&
-        instr->operand(0)->shape().element_type() == F16 &&
+    if (!IsCustomCallToDnnConvolution(*instr)) {
+      continue;
+    }
+    auto* custom_call = Cast<HloCustomCallInstruction>(instr);
+    if (custom_call->operand(0)->shape().element_type() == F16 &&
         // TODO(timshen): Disable for fused conv for now. Implement it if it's
         // needed.
-        Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+        custom_call->custom_call_target() !=
             kCudnnConvBiasActivationForwardCallTarget) {
-      convs.push_back(instr);
+      convs.push_back(custom_call);
     }
   }
   return convs;
@@ -241,7 +255,7 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
 StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* comp : module->MakeNonfusionComputations()) {
-    for (HloInstruction* conv : GetRelevantConvs(comp)) {
+    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
       TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv));
       changed |= result;
     }
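
Note: the padding done by this pass amounts to rounding each convolution feature dimension up to the next multiple of kDesiredNumFeaturesFactor before handing the conv to cudnn. The standalone sketch below illustrates that rounding rule only; it is not XLA code, it assumes a factor of 8 (the multiple fp16 Tensor Core convolutions prefer), and RoundUpToMultiple is a hypothetical stand-in for the rounding performed inside PadShape.

// Illustrative sketch only -- not the XLA implementation. Assumes the
// Tensor Core factor is 8; RoundUpToMultiple is a hypothetical helper.
#include <cstdint>
#include <iostream>

constexpr int64_t kAssumedFeaturesFactor = 8;

// Rounds n up to the next multiple of factor (e.g. 3 -> 8, 50 -> 56).
int64_t RoundUpToMultiple(int64_t n, int64_t factor) {
  return (n + factor - 1) / factor * factor;
}

int main() {
  // A conv with 3 input features would be padded to 8 channels, and one
  // with 50 output features would be padded to 56.
  std::cout << RoundUpToMultiple(3, kAssumedFeaturesFactor) << "\n";   // 8
  std::cout << RoundUpToMultiple(50, kAssumedFeaturesFactor) << "\n";  // 56
  return 0;
}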