aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
blob: d63287539dfde5bb4890ab8303ef2205133d8125 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Visits one computation and folds BF16 <-> F32 convert instructions into
// their neighbouring HLOs when `bfloat16_support_` says the HLO can consume
// or produce BF16 directly.
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  // Attempts to fold conversions around `hlo`; certain opcodes are skipped.
  Status DefaultAction(HloInstruction* hlo) override;

  // Special handling for cross-replica-sum which can have a tuple output.
  Status HandleCrossReplicaSum(HloInstruction* crs) override;

  // Runs a fresh visitor over `computation` and returns whether anything was
  // changed. CHECK-fails if the traversal returns a non-OK status.
  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
  // conversion as output, and folds them to the HLO itself if feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand to the HLO.
  //
  // Precondition: the operand is a BF16 -> F32 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  // The computation being visited; consulted to identify its root.
  HloComputation* computation_;
  // Backend hooks deciding where BF16 is acceptable. Never deleted here.
  const BFloat16Support* bfloat16_support_;
  // Becomes true once any instruction has been rewritten.
  bool changed_ = false;
};

// Makes `hlo` produce BF16 directly and points every consumer of its
// F32 -> BF16 convert users at `hlo` itself. The converts are left in place.
Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
  // Snapshot the users up front: the rewiring below can change hlo's user
  // list while we are iterating.
  const std::vector<HloInstruction*> convert_users = hlo->users();

  hlo->mutable_shape()->set_element_type(BF16);

  for (HloInstruction* convert : convert_users) {
    CHECK_EQ(convert->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}

// Bypasses the BF16 -> F32 convert feeding operand `operand_index` of `hlo`
// by wiring the convert's own (BF16) input straight into `hlo`. The convert
// instruction itself is not removed here.
Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  HloInstruction* convert = hlo->mutable_operand(operand_index);
  CHECK_EQ(convert->opcode(), HloOpcode::kConvert);

  HloInstruction* bf16_input = convert->mutable_operand(0);
  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_index, bf16_input));

  changed_ = true;
  return Status::OK();
}

namespace {

// Returns whether hlo has users and all users are conversions from F32 to BF16.
bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) {
  if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) {
    return false;
  }
  for (const auto user : hlo->users()) {
    if (user->opcode() == HloOpcode::kConvert &&
        user->shape().element_type() == BF16) {
      continue;
    }
    return false;
  }
  return true;
}

}  // namespace

// Folds BF16 -> F32 converts feeding `hlo` and F32 -> BF16 converts consuming
// it, subject to `bfloat16_support_`. When the HLO cannot mix precisions,
// nothing is folded unless every remaining operand/output can become BF16.
Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  // Split the F32 operands into those fed by a foldable BF16 -> F32 convert
  // and the rest, which will stay F32.
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    const HloInstruction* operand = hlo->operand(i);
    if (operand->shape().element_type() != F32) {
      continue;
    }
    if (operand->opcode() == HloOpcode::kConvert &&
        operand->operand(0)->shape().element_type() == BF16 &&
        bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      // Operand is a convert from BF16 to F32 and we support BF16 input
      // directly in the current HLO at the operand index.
      bf16_to_f32_operands.push_back(i);
    } else {
      has_other_f32_operands = true;
    }
  }

  // The root's value is also the computation's result, a use that
  // AllUsersAreF32ToBF16Converts() cannot see; changing the root's output
  // type would change the computation's result type, so never fold it.
  const bool fold_output_conversion =
      hlo != computation_->root_instruction() &&
      AllUsersAreF32ToBF16Converts(hlo) &&
      bfloat16_support_->SupportsBF16Output(*hlo);

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}

// Entry point for every instruction without a dedicated handler: filters out
// opcodes whose types must not be touched, then tries to fold conversions.
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, bitcast-convert, and control flow.
  switch (hlo->opcode()) {
    case HloOpcode::kTuple:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kConstant:
    case HloOpcode::kParameter:
    case HloOpcode::kFusion:
    case HloOpcode::kConvert:
    // bitcast-convert reinterprets raw bits, so its result depends on the bit
    // width of its operand and output; swapping a 32-bit F32 for a 16-bit
    // BF16 on either side would change (or invalidate) its meaning.
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCall:
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
    case HloOpcode::kConditional:
      return Status::OK();
    default:
      break;
  }

  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output, so folding
    // can only happen when it supports mixed precision so that we can change
    // its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

// Folds conversions around a cross-replica-sum. Operands go through
// DefaultAction(); for a tuple-shaped output, each tuple element is switched
// to BF16 independently when all of its readers immediately convert to BF16.
Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
    HloInstruction* crs) {
  // First use DefaultAction() to handle the operands. It can't handle
  // tuple-shaped output.
  TF_RETURN_IF_ERROR(DefaultAction(crs));

  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
    return Status::OK();
  }

  // If the output is not a tuple, we don't need special handling.
  if (!ShapeUtil::IsTuple(crs->shape())) {
    return Status::OK();
  }

  // If crs is the root instruction, we should keep its original output type.
  // The root instruction implicitly has a use from being the result of the
  // computation, and the code below does not take this use into account.
  if (crs == computation_->root_instruction()) {
    return Status::OK();
  }

  // Group the get-tuple-element users by tuple index. If any user is not a
  // get-tuple-element, leave the output untouched.
  std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
      crs->operand_count());
  for (HloInstruction* user : crs->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      return Status::OK();
    }
    per_tuple_element_gtes[user->tuple_index()].push_back(user);
  }

  for (int64 i = 0; i < crs->operand_count(); ++i) {
    const std::vector<HloInstruction*>& gtes = per_tuple_element_gtes[i];

    // An element with no get-tuple-element users has no conversions to fold.
    // Without this check the loop below would succeed vacuously, retyping the
    // element to BF16 without folding anything or recording a change.
    if (gtes.empty()) {
      continue;
    }

    // Fold conversions only when all the get-tuple-elements' users are
    // conversions from F32 to BF16.
    bool all_gte_users_are_bf16_convert = true;
    for (const HloInstruction* gte : gtes) {
      if (!AllUsersAreF32ToBF16Converts(gte)) {
        all_gte_users_are_bf16_convert = false;
        break;
      }
    }
    if (!all_gte_users_are_bf16_convert) {
      continue;
    }

    // Retype tuple element i of the output, then fold each reader's converts.
    ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
        ->set_element_type(BF16);
    for (HloInstruction* gte : gtes) {
      TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
    }
  }

  return Status::OK();
}

// Runs the folding visitor over every computation returned by
// MakeNonfusionComputations(), logging the module at VLOG level 2 before and
// after. Returns true if any computation was changed.
StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());

  bool module_changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    // Every computation is visited; the result only accumulates.
    const bool computation_changed =
        BFloat16ConversionFoldingVisitor::Run(computation, bfloat16_support_);
    module_changed = module_changed || computation_changed;
  }

  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return module_changed;
}

}  // namespace xla