aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
blob: 39036e205e76979e7da08246cd030ebd17e52f76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

namespace {

struct PassThrough {
  PassThrough(HloInstruction* user, HloInstruction* operand)
      : user(user), operand(operand) {}

  HloInstruction* user = nullptr;
  HloInstruction* operand = nullptr;
};

// Assigns `sharding` to `instruction` via HloInstruction::set_single_sharding(),
// logging the assignment at VLOG level 4.
void SetSingleSharding(HloInstruction* instruction,
                       const HloSharding& sharding) {
  HloInstruction* const target = instruction;
  VLOG(4) << "  " << target->name() << " to " << sharding;
  target->set_single_sharding(sharding);
}

// Returns whether two shardings are equivalent. When both shardings collapse
// to a single sharding (ExtractSingleSharding()), those single shardings are
// compared. Otherwise the full shardings are compared with operator==.
bool ShardingMatches(const HloSharding& sharding1,
                     const HloSharding& sharding2) {
  // The second extraction only happens when the first one succeeded.
  if (auto lhs_single = sharding1.ExtractSingleSharding()) {
    if (auto rhs_single = sharding2.ExtractSingleSharding()) {
      return *lhs_single == *rhs_single;
    }
  }
  // At least one side is not uniform across all elements: fall back to a
  // full sharding comparison.
  return sharding1 == sharding2;
}

// When we create domains, they are never "empty", where by "empty" we mean
// that a kDomain instruction has another kDomain instruction of the same kind
// as its operand.
// But when the HLO optimizations are run, empty domains can be created.
// For example:
//
//  Domain(device=None, device=0) ->
//    Tuple(device=0) ->
//      GTE(device=0) ->
//        Domain(device=0, device=None)
//
// In that case the tuple simplifier could create something like:
//
//  Domain(device=None, device=0) -> Domain(device=0, device=None)
//
// Which is a so called empty domain.
// In the case above, crossing an empty domain that was transiting through
// device 0 requires the normalization phase to fix up the empty domain by
// adding back a Tuple+GTE pair with the proper device.
// One particular case where this can create problems is the result of the
// entry computation, where the GTE assignments are used by TF to tell
// XLA where the results should be sent.
// Returns every (exit kDomain, enter kDomain) pair of `domain` where an enter
// kDomain is used directly by one of the domain's exit kDomains. CHECK-fails if
// any entry of domain.enter_domains is not a kDomain.
std::vector<PassThrough> LocatePassThroughDomainLinks(
    const DomainMetadata::Domain& domain) {
  std::vector<PassThrough> links;
  for (HloInstruction* enter : domain.enter_domains) {
    CHECK(enter->opcode() == HloOpcode::kDomain)
        << "Instruction is not a kDomain: " << enter->ToString();
    for (HloInstruction* consumer : enter->users()) {
      const bool is_exit_domain =
          consumer->opcode() == HloOpcode::kDomain &&
          domain.exit_domains.count(consumer) != 0;
      if (!is_exit_domain) {
        continue;
      }
      links.emplace_back(consumer, enter);
      VLOG(2) << "Found passthrough domain link:";
      VLOG(2) << "  " << consumer->ToString();
      VLOG(2) << "  " << enter->ToString();
    }
  }
  return links;
}

// Repairs empty domains. For each pass-through link returned by
// LocatePassThroughDomainLinks(), a Tuple + GetTupleElement pair is inserted
// between the enter kDomain and the exit kDomain. The GTE gets `sharding`, and
// the exit kDomain is rewired to consume the GTE instead of the enter kDomain.
// Returns the first error reported by ReplaceUseWith(), if any.
Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
                                   const HloSharding& sharding) {
  const std::vector<PassThrough> links = LocatePassThroughDomainLinks(domain);
  for (const PassThrough& link : links) {
    HloInstruction* const enter = link.operand;
    auto* const computation = enter->parent();
    HloInstruction* wrapper =
        computation->AddInstruction(HloInstruction::CreateTuple({enter}));
    HloInstruction* element = computation->AddInstruction(
        HloInstruction::CreateGetTupleElement(enter->shape(), wrapper, 0));
    element->set_sharding(sharding);
    TF_RETURN_IF_ERROR(enter->ReplaceUseWith(link.user, element));
  }
  return Status::OK();
}

// Returns a heap-allocated copy of `sharding`. If ExtractSingleSharding()
// yields a single sharding, that collapsed form is copied instead of the
// original.
std::unique_ptr<HloSharding> CloneShardingForDomain(
    const HloSharding& sharding) {
  if (auto single = sharding.ExtractSingleSharding()) {
    return MakeUnique<HloSharding>(*single);
  }
  return MakeUnique<HloSharding>(sharding);
}

// Assigns `sharding` to every instruction in domain.instructions that has no
// sharding yet. Instructions that already carry one are only logged.
// Always returns OK.
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
                                 const HloSharding& sharding) {
  VLOG(4) << "Applying " << sharding << " sharding";
  for (HloInstruction* member : domain.instructions) {
    if (member->has_sharding()) {
      // Existing shardings are preserved so we do not interfere with HLO passes
      // that may rely on them.
      VLOG(4) << "  " << member->name() << " already has sharding "
              << member->sharding();
      continue;
    }
    SetSingleSharding(member, sharding);
  }
  return Status::OK();
}

// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree.
// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate()
// sharding will be returned.
ShapeTree<HloSharding> GetTupleSharding(HloInstruction* tuple) {
  const auto& shape = tuple->shape();
  // Without an explicit sharding, every leaf defaults to Replicate().
  if (!tuple->has_sharding()) {
    return ShapeTree<HloSharding>(shape, HloSharding::Replicate());
  }
  return tuple->sharding().GetAsShapeTree(shape);
}

// Retrieves the sharding of operand, asked from a user instruction which is
// within domain. If operand is a kDomain, it means that sharding argument is
// the operand sharding, otherwise the operand's own sharding will be returned.
const HloSharding* GetOperandSharding(const HloInstruction* operand,
                                      const DomainMetadata::Domain& domain,
                                      const HloSharding& sharding) {
  // The domain sets hold non-const pointers, so constness is stripped only to
  // perform the lookups.
  HloInstruction* const key = const_cast<HloInstruction*>(operand);
  DCHECK_EQ(domain.reach_set.count(key), 1);
  // An enter kDomain of the domain being processed stands for the domain
  // sharding itself.
  const bool is_enter_domain = operand->opcode() == HloOpcode::kDomain &&
                               domain.enter_domains.count(key) != 0;
  if (is_enter_domain) {
    return &sharding;
  }
  // Any other operand reports its own sharding, or nullptr when it has none.
  if (operand->has_sharding()) {
    return &operand->sharding();
  }
  return nullptr;
}

// Tries to propagate the sharding information into the instructions that are
// part of the domain, in a post order manner (operand propagate to user).
// Performs ONE propagation pass over domain.instructions, skipping every
// instruction that already has a sharding:
//  - kGetTupleElement: takes the sub-sharding of its tuple operand at
//    tuple_index(), or the operand sharding itself when that is not a tuple
//    sharding.
//  - kTuple: starts from an all-HloSharding::Replicate() shape tree and
//    overwrites element i with the sharding of operand i, when known.
//  - anything else: takes the sharding shared by all operands that have one;
//    any disagreement leaves the instruction unassigned for this pass.
// Operand shardings are resolved by GetOperandSharding(), so an operand that
// is an enter kDomain of `domain` contributes `sharding`.
// NOTE(review): instructions are visited in the iteration order of
// domain.instructions; the "post order" claim above depends on how that
// container was populated — confirm.
// Returns the number of instructions assigned a sharding during this pass.
StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
                                        const HloSharding& sharding) {
  int64 assigned = 0;
  for (HloInstruction* instruction : domain.instructions) {
    // Never override a sharding that is already present.
    if (instruction->has_sharding()) {
      continue;
    }
    if (instruction->opcode() == HloOpcode::kGetTupleElement) {
      HloInstruction* tuple = instruction->mutable_operand(0);
      const HloSharding* tuple_sharding =
          GetOperandSharding(tuple, domain, sharding);
      if (tuple_sharding != nullptr) {
        if (tuple_sharding->IsTuple()) {
          HloSharding sub_sharding = tuple_sharding->GetSubSharding(
              tuple->shape(), {instruction->tuple_index()});
          VLOG(4) << "  " << instruction->name() << " to sharding "
                  << sub_sharding;
          instruction->set_sharding(sub_sharding);
        } else {
          SetSingleSharding(instruction, *tuple_sharding);
        }
        ++assigned;
      }
    } else if (instruction->opcode() == HloOpcode::kTuple) {
      int64 tuple_assigned = 0;
      // `instruction` has no sharding at this point (checked above), so this
      // is an all-Replicate() tree.
      ShapeTree<HloSharding> shape_tree = GetTupleSharding(instruction);
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        const HloSharding* operand_sharding =
            GetOperandSharding(instruction->operand(i), domain, sharding);
        // Only elements that actually change are counted, so a tuple whose
        // known operand shardings all equal Replicate() stays unassigned.
        if (operand_sharding != nullptr &&
            shape_tree.element({i}) != *operand_sharding) {
          *shape_tree.mutable_element({i}) = *operand_sharding;
          ++tuple_assigned;
        }
      }
      if (tuple_assigned > 0) {
        HloSharding tuple_sharding = HloSharding::Tuple(shape_tree);
        VLOG(4) << "  " << instruction->name() << " to sharding "
                << tuple_sharding;
        instruction->set_sharding(tuple_sharding);
        ++assigned;
      }
    } else {
      // If all the operand of the given instruction has the same single device
      // assignment, assign that device to this instruction as well.
      // NOTE(review): operands whose sharding resolves to nullptr are skipped,
      // so a single sharded operand is enough to assign its sharding even if
      // other operands are still unassigned — confirm this is intended.
      const HloSharding* common_sharding = nullptr;
      for (const HloInstruction* operand : instruction->operands()) {
        const HloSharding* operand_sharding =
            GetOperandSharding(operand, domain, sharding);
        if (operand_sharding != nullptr) {
          if (common_sharding != nullptr &&
              *common_sharding != *operand_sharding) {
            common_sharding = nullptr;
            break;
          }
          common_sharding = operand_sharding;
        }
      }
      if (common_sharding != nullptr) {
        VLOG(4) << "  " << instruction->name() << " to sharding "
                << *common_sharding;
        instruction->set_sharding(*common_sharding);
        ++assigned;
      }
    }
  }
  return assigned;
}

// Assigns shardings to the instructions of `domain` based on the domain
// `sharding`.
// If `sharding` collapses to a single sharding (ExtractSingleSharding()), that
// single sharding is applied to every domain instruction without one.
// Otherwise ApplyDomainShardingPass() is repeated until a pass assigns nothing.
// Instructions still lacking a sharding afterwards are reported with
// LOG(WARNING), but do not produce an error.
// Returns the first error from a propagation pass, if any.
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
                           const HloSharding& sharding) {
  // Here is the place to call external sharding normalizers, which are
  // implemented in other modules (ie, spatial partitioning).
  // The signature of the external normalizer function should be something
  // like:
  //
  //   StatusOr<bool> Normalizer(const DomainMetadata::Domain&,
  //                             const HloSharding& sharding);
  //
  // The function should return true if it has processed the domain
  // normalization, false if domain was not one recognized by it, or an error.
  // We will call the functions in order below, and fall back to local code if
  // none of the external normalizers acted on the domain.
  // External normalizers should not handle the cases that are already handled
  // locally.

  // No external normalizers are currently invoked, so the domain is always
  // handled locally. First check whether this is a single sharding.
  auto single_sharding = sharding.ExtractSingleSharding();
  if (single_sharding) {
    // Shortcut the simple case. We have a unique sharding, so we call
    // the ApplyDomainSingleSharding() API which will apply array or tuple
    // shaped sharding to the domain instructions.
    return ApplyDomainSingleSharding(domain, *single_sharding);
  }
  VLOG(1) << "Assigning non-trivial sharding " << sharding;
  // A pass only assigns shardings to instructions that have none, so every
  // productive pass reduces the number of unassigned instructions; the loop
  // stops at the first pass that makes no progress.
  int64 assigned = 0;
  do {
    TF_ASSIGN_OR_RETURN(assigned, ApplyDomainShardingPass(domain, sharding));
  } while (assigned != 0);

  int64 unassigned = 0;
  for (HloInstruction* instruction : domain.instructions) {
    if (!instruction->has_sharding()) {
      LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
      ++unassigned;
    }
  }
  if (unassigned > 0) {
    // TODO: decide whether leaving instructions unassigned should be an error
    // instead of a warning.
    LOG(WARNING) << "Propagation of non-trivial sharding " << sharding
                 << " left " << unassigned << " instruction(s) unassigned";
  }
  return Status::OK();
}

// Creates a kDomain instruction to be placed between instruction and operand.
// The kDomain instruction will be created only if the sharding differ between
// the instruction and the operand.
std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
                                             HloInstruction* operand) {
  const bool user_sharded = instruction->has_sharding();
  const bool operand_sharded = operand->has_sharding();
  // Neither side carries a sharding: there is no boundary to mark.
  if (!user_sharded && !operand_sharded) {
    return nullptr;
  }
  // Both sides carry equivalent shardings: no boundary either.
  if (user_sharded && operand_sharded &&
      ShardingMatches(instruction->sharding(), operand->sharding())) {
    return nullptr;
  }

  // Domain-side copies of each sharding (nullptr for an unsharded side).
  std::unique_ptr<HloSharding> user_side_sharding;
  if (user_sharded) {
    user_side_sharding = CloneShardingForDomain(instruction->sharding());
  }
  std::unique_ptr<HloSharding> operand_side_sharding;
  if (operand_sharded) {
    operand_side_sharding = CloneShardingForDomain(operand->sharding());
  }

  auto describe = [](const std::unique_ptr<HloSharding>& side) -> std::string {
    if (side == nullptr) {
      return "None";
    }
    return side->ToString();
  };
  VLOG(3) << "Creating domain:";
  VLOG(3) << "  Instruction: " << instruction->name();
  VLOG(3) << "  Operand: " << operand->name();
  VLOG(3) << "    User side sharding: " << describe(user_side_sharding);
  VLOG(3) << "    Operand side sharding: " << describe(operand_side_sharding);

  std::unique_ptr<DomainMetadata> operand_metadata =
      MakeUnique<ShardingMetadata>(std::move(operand_side_sharding));
  std::unique_ptr<DomainMetadata> user_metadata =
      MakeUnique<ShardingMetadata>(std::move(user_side_sharding));
  return HloInstruction::CreateDomain(operand->shape(), operand,
                                      std::move(operand_metadata),
                                      std::move(user_metadata));
}

// Returns a domain-ready copy (via CloneShardingForDomain()) of the sharding
// shared by the sharded entries of `instructions`. Returns a null pointer if
// none of them has a sharding. Fails through TF_RET_CHECK when two sharded
// instructions disagree according to ShardingMatches().
StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
  // Reaching here means no kDomain was inserted between these instructions.
  // They are expected to be in the same computation and to share a sharding
  // (or have none).
  const HloSharding* common = nullptr;
  for (HloInstruction* instruction : instructions) {
    if (!instruction->has_sharding()) {
      continue;
    }
    if (common == nullptr) {
      common = &instruction->sharding();
      continue;
    }
    TF_RET_CHECK(ShardingMatches(*common, instruction->sharding()))
        << "Sharding " << *common << " does not match the one in "
        << instruction->ToString();
  }
  if (common == nullptr) {
    return std::unique_ptr<HloSharding>();
  }
  VLOG(4) << "Extracted sharding is " << *common;
  return CloneShardingForDomain(*common);
}

}  // namespace

std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
  std::unique_ptr<HloSharding> sharding;
  if (sharding_ != nullptr) {
    sharding = MakeUnique<HloSharding>(*sharding_);
  }
  return MakeUnique<ShardingMetadata>(std::move(sharding));
}

bool ShardingMetadata::Matches(const DomainMetadata& other) const {
  const ShardingMetadata* other_ptr =
      dynamic_cast<const ShardingMetadata*>(&other);
  if (other_ptr == nullptr) {
    // If other is not a ShardingMetadata, then it is clearly a no match.
    return false;
  }
  if (sharding_ == nullptr) {
    return other_ptr->sharding_ == nullptr;
  }
  return other_ptr->sharding_ != nullptr
             ? ShardingMatches(*sharding_, *other_ptr->sharding_)
             : false;
}

// Textual form of the held sharding, or "{}" when there is none.
string ShardingMetadata::ToString() const {
  if (sharding_ == nullptr) {
    return "{}";
  }
  return sharding_->ToString();
}

// With a sharding present, applies it to `domain` and then repairs any empty
// (pass-through) domain links. Without a sharding, this is a no-op.
Status ShardingMetadata::NormalizeInstructions(
    const DomainMetadata::Domain& domain) const {
  if (sharding_ == nullptr) {
    return Status::OK();
  }
  VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":";
  TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_));
  TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_));
  return Status::OK();
}

// Normalizes a domain that has no sharding metadata. It recovers the sharding
// shared by the domain instructions and applies it. If no such sharding
// exists, the domain is left untouched.
Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> common_sharding,
                      ExtractOriginalCommonSharding(domain.instructions));
  if (common_sharding == nullptr) {
    VLOG(1) << "Unable to find common sharding";
    return Status::OK();
  }
  VLOG(4) << "Normalizing sharding-less domain to "
          << common_sharding->ToString() << ":";
  TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *common_sharding));
  return Status::OK();
}

// Public entry point. Returns the kDomain created by CreateDomain() for the
// `operand` -> `instruction` edge, or nullptr when no domain is needed.
std::unique_ptr<HloInstruction> CreateShardingDomain(
    HloInstruction* instruction, HloInstruction* operand) {
  std::unique_ptr<HloInstruction> domain = CreateDomain(instruction, operand);
  return domain;
}

}  // namespace xla