aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
blob: e3f4a9852ace86c20610362aa6ad3c3d9c78de30 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

namespace {

// Outcome of an attempt to assign a sharding during tuple domain sharding
// propagation. Used together with kUnassignedDevice to tell apart the three
// possible results of merging user shardings into an instruction.
enum class AssignmentKind {
  // No assignment has occurred.
  kUnassigned,
  // At least one assignment has occurred.
  kAssigned,
  // No assignment has occurred because propagations conflicted, which happens
  // when multiple users of an instruction have different shardings.
  kConflict,
};

// Reserved device id marking a tuple leaf sharding as "not known yet" during
// sharding propagation. It lets a tuple sharding carry partial information.
// It may only appear in tuple leaf shardings. Once propagation finishes, any
// tuple-shaped instruction whose sharding still uses it has its sharding
// cleared.
// TODO(b/112883246): Centralized enum of reserved devices.
constexpr int64 kUnassignedDevice{-2};

// A pass-through domain link as found by LocatePassThroughDomainLinks().
// `operand` is an enter-domain kDomain instruction. `user` is the exit-domain
// kDomain instruction that consumes it directly, or nullptr when `operand` is
// the root of its computation.
struct PassThrough {
  PassThrough(HloInstruction* link_user, HloInstruction* link_operand)
      : user(link_user), operand(link_operand) {}

  HloInstruction* user = nullptr;     // Consumer side of the link, or nullptr.
  HloInstruction* operand = nullptr;  // Enter-domain kDomain instruction.
};

// Installs `sharding` on `instruction` through
// HloInstruction::set_single_sharding() and logs the assignment at VLOG(4).
void SetSingleSharding(HloInstruction* instruction,
                       const HloSharding& sharding) {
  const auto& target_name = instruction->name();
  VLOG(4) << "  " << target_name << " to " << sharding;
  instruction->set_single_sharding(sharding);
}

// Returns whether two shardings are equivalent for domain purposes.
// If both shardings reduce to a single sharding (ExtractSingleSharding()
// succeeds for both), only those single shardings are compared. Any sharding
// that is not unique across all of its elements is compared in full.
bool ShardingMatches(const HloSharding& sharding1,
                     const HloSharding& sharding2) {
  if (auto lhs_single = sharding1.ExtractSingleSharding()) {
    if (auto rhs_single = sharding2.ExtractSingleSharding()) {
      return *lhs_single == *rhs_single;
    }
  }
  return sharding1 == sharding2;
}

// Finds the "pass-through" links of `domain`. A link is an enter-domain kDomain
// instruction that either feeds an exit-domain kDomain instruction directly or
// is itself the root of its computation.
//
// Domains are never created "empty". Here, empty means a kDomain instruction
// whose operand is another kDomain instruction of the same kind. HLO
// optimizations can still produce empty domains. For example:
//
//  Domain(device=None, device=0) ->
//    Tuple(device=0) ->
//      GTE(device=0) ->
//        Domain(device=0, device=None)
//
// can be turned by the tuple simplifier into:
//
//  Domain(device=None, device=0) -> Domain(device=0, device=None)
//
// which is an empty domain. Such a domain was transiting through device 0, so
// crossing it requires the normalization phase to add back a Tuple+GTE pair
// carrying the proper device. One place where this matters is the result of the
// entry computation, where TF uses the GTE assignments to tell XLA where the
// results should be sent.
//
// Returns one PassThrough per link. For a link whose enter domain is the
// computation root, PassThrough::user is nullptr.
std::vector<PassThrough> LocatePassThroughDomainLinks(
    const DomainMetadata::Domain& domain) {
  std::vector<PassThrough> links;
  for (HloInstruction* enter : domain.enter_domains) {
    CHECK(enter->opcode() == HloOpcode::kDomain)
        << "Instruction is not a kDomain: " << enter->ToString();
    for (HloInstruction* user : enter->users()) {
      const bool is_exit_domain = user->opcode() == HloOpcode::kDomain &&
                                  domain.exit_domains.count(user) != 0;
      if (!is_exit_domain) {
        continue;
      }
      links.emplace_back(user, enter);
      VLOG(2) << "Found passthrough domain link:";
      VLOG(2) << "  " << user->ToString();
      VLOG(2) << "  " << enter->ToString();
    }
    if (enter != enter->parent()->root_instruction()) {
      continue;
    }
    links.emplace_back(nullptr, enter);
    VLOG(2) << "Found passthrough domain link:";
    VLOG(2) << "  <root>";
    VLOG(2) << "  " << enter->ToString();
  }
  return links;
}

// Repairs each pass-through link of `domain`, as located by
// LocatePassThroughDomainLinks():
// - Inserts a Tuple(operand) followed by a GetTupleElement(tuple, 0) with the
//   operand's shape.
// - Sets `sharding` on the new GTE.
// - Redirects the link's user to the GTE. If the link has no user (the operand
//   is the computation root), the GTE becomes the new root instead.
// Returns the first error from ReplaceUseWith(), or OK.
Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
                                   const HloSharding& sharding) {
  for (const PassThrough& link : LocatePassThroughDomainLinks(domain)) {
    HloInstruction* operand = link.operand;
    HloComputation* computation = operand->parent();
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple({operand}));
    HloInstruction* gte = computation->AddInstruction(
        HloInstruction::CreateGetTupleElement(operand->shape(), tuple, 0));
    gte->set_sharding(sharding);
    if (link.user == nullptr) {
      computation->set_root_instruction(gte);
      continue;
    }
    TF_RETURN_IF_ERROR(operand->ReplaceUseWith(link.user, gte));
  }
  return Status::OK();
}

// Normalizes a sharding before it is used for domain creation or comparison.
// If ExtractSingleSharding() succeeds (e.g. a tuple sharding whose elements
// all have the same sharding), a new shared copy of that single sharding is
// returned. Treating such tuples as single shardings inserts fewer domain
// separations, and domains can block some optimizations, so we want as few as
// possible. Otherwise the input pointer itself is returned.
std::shared_ptr<const HloSharding> CloneShardingForDomain(
    std::shared_ptr<const HloSharding> sharding) {
  if (auto collapsed = sharding->ExtractSingleSharding()) {
    return std::make_shared<const HloSharding>(*collapsed);
  }
  return sharding;
}

// Sets `sharding` on every instruction of `domain` that has no sharding yet.
// Instructions that already carry a sharding are only logged. Overwriting them
// could interfere with HLO passes that rely on that knowledge.
// Always returns OK.
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
                                 const HloSharding& sharding) {
  VLOG(4) << "Applying " << sharding << " sharding";
  for (HloInstruction* instruction : domain.instructions) {
    if (instruction->has_sharding()) {
      VLOG(4) << "  " << instruction->name() << " already has sharding "
              << instruction->sharding();
      continue;
    }
    SetSingleSharding(instruction, sharding);
  }
  return Status::OK();
}

// Returns the sharding of `user` as a ShapeTree. `user` is assumed to be a
// user of `instruction`.
// - For a non-kTuple user, the tree follows `user`'s own shape.
// - For a kTuple user, only the tuple sub-sharding at the operand index of
//   `instruction` is returned, laid out over `instruction`'s shape. That is
//   the part of the tuple sharding that corresponds to `instruction`.
ShapeTree<HloSharding> GetShardingTreeFromUser(
    const HloInstruction& instruction, const HloInstruction& user) {
  if (user.opcode() != HloOpcode::kTuple) {
    return user.sharding().GetAsShapeTree(user.shape());
  }
  const int64 operand_index = user.operand_index(&instruction);
  const HloSharding operand_sharding =
      user.sharding().GetSubSharding(user.shape(), {operand_index});
  return operand_sharding.GetAsShapeTree(instruction.shape());
}

// Merges leaf sharding `rhs` into leaf sharding `*lhs`. Both must be non-tuple
// (TF_RET_CHECK-ed). Returns:
// - kUnassigned, leaving *lhs untouched, when rhs uses kUnassignedDevice.
//   This means kUnassignedDevice is never propagated.
// - kAssigned after copying rhs into *lhs, when *lhs uses kUnassignedDevice.
// - Otherwise *lhs is left untouched. The result is kConflict when the two
//   UniqueDevice() values differ, and kUnassigned when they agree.
StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
                                            const HloSharding& rhs) {
  TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
  if (rhs.UsesDevice(kUnassignedDevice)) {
    return AssignmentKind::kUnassigned;
  }
  if (!lhs->UsesDevice(kUnassignedDevice)) {
    // lhs already holds a real sharding: only report whether rhs disagrees.
    const bool devices_differ = lhs->UniqueDevice() != rhs.UniqueDevice();
    return devices_differ ? AssignmentKind::kConflict
                          : AssignmentKind::kUnassigned;
  }
  *lhs = rhs;
  return AssignmentKind::kAssigned;
}

// Merges every leaf of `rhs_tree` into `*lhs_tree`. The two trees are walked in
// lockstep, with the walk over `*lhs_tree` starting at `lhs_it`. Each rhs leaf
// must line up with an lhs leaf, and the whole of `rhs_tree` must be consumed
// (both TF_RET_CHECK-ed).
//
// Returns:
// - kAssigned if at least one leaf was assigned, kUnassigned if none was.
// - kConflict as soon as a leaf conflicts. At that point `*lhs_tree` may
//   already be partially assigned up to the conflicting leaf, and the caller
//   is responsible for discarding it.
StatusOr<AssignmentKind> AssignTreeSharding(
    ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
    const ShapeTree<HloSharding>& rhs_tree) {
  AssignmentKind overall = AssignmentKind::kUnassigned;
  auto rhs_it = rhs_tree.begin();
  while (lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end()) {
    // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
    if (rhs_tree.IsLeaf(rhs_it->first)) {
      TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
      TF_ASSIGN_OR_RETURN(AssignmentKind leaf_outcome,
                          AssignLeafSharding(&lhs_it->second, rhs_it->second));
      if (leaf_outcome == AssignmentKind::kConflict) {
        return AssignmentKind::kConflict;
      }
      if (leaf_outcome == AssignmentKind::kAssigned) {
        overall = AssignmentKind::kAssigned;
      }
    }
    ++lhs_it;
    ++rhs_it;
  }
  TF_RET_CHECK(rhs_it == rhs_tree.end());
  return overall;
}

// Tries to derive a sharding for `instruction` from its users and, if one is
// found, sets it on `instruction`.
//
// - If `instruction` has no users, it takes `domain_sharding` directly. This is
//   only valid for a tuple-shaped instruction whose leaf count equals the
//   number of tuple elements of `domain_sharding` (TF_RET_CHECK-ed).
// - Users are then visited in order. A user that is a kDomain registered among
//   `domain.exit_domains` makes `instruction` take `domain_sharding` at once.
//   Users without a sharding are skipped.
// - Sharded users contribute their (sub)shardings to a ShapeTree whose leaves
//   start out as kUnassignedDevice. For a tuple-shaped `instruction` with a
//   kGetTupleElement user, only the subtree at that user's tuple index is
//   affected.
//
// Returns:
// - true if a sharding was set on `instruction`. A resulting tuple sharding
//   may still contain kUnassignedDevice leaves for indices that no sharded
//   user covered.
// - false if no user provided a sharding, or if two users' shardings
//   conflicted. On conflict, nothing is assigned.
StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
                                      const DomainMetadata::Domain& domain,
                                      const HloSharding& domain_sharding) {
  if (instruction->users().empty()) {
    // No sharding from users, use domain_sharding, after checking
    // compatibility.
    TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) &&
                 ShapeUtil::GetLeafCount(instruction->shape()) ==
                     domain_sharding.tuple_elements().size());
    instruction->set_sharding(domain_sharding);
    return true;
  }
  AssignmentKind assigned = AssignmentKind::kUnassigned;
  // The sharding_tree leaves are initialized to kUnassignedDevice. Only tuple
  // subshardings can end up with kUnassignedDevice leaves in the final
  // assignment, when some tuple indexes are unused or are used only by users
  // that have no sharding.
  // Non-tuple shardings are either assigned a real sharding or not assigned at
  // all, so they never end up with kUnassignedDevice. In any case,
  // AssignLeafSharding never propagates kUnassignedDevice.
  ShapeTree<HloSharding> sharding_tree(
      instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
  for (HloInstruction* user : instruction->users()) {
    // `user` is already a non-const pointer, so it can be looked up in
    // exit_domains directly (as LocatePassThroughDomainLinks does).
    if (user->opcode() == HloOpcode::kDomain &&
        domain.exit_domains.count(user) > 0) {
      // The user is a registered domain exit. The instruction takes its
      // sharding directly from the domain, and no further users are visited.
      instruction->set_sharding(domain_sharding);
      return true;
    }
    if (!user->has_sharding()) {
      continue;
    }
    AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
    ShapeTree<HloSharding> user_sharding_tree =
        GetShardingTreeFromUser(*instruction, *user);
    if (ShapeUtil::IsTuple(instruction->shape())) {
      // Tuple-shaped instruction: collect the individual tuple subshardings
      // from its users, then combine them into the tuple sharding.
      // A GTE user's sharding concerns only the subtree of sharding_tree at
      // index user->tuple_index(). Any other user affects the whole tree.
      ShapeTree<HloSharding>::iterator sharding_tree_begin =
          user->opcode() == HloOpcode::kGetTupleElement
              ? sharding_tree.find({user->tuple_index()})
              : sharding_tree.begin();
      TF_ASSIGN_OR_RETURN(
          sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
                                           user_sharding_tree));
    } else {
      // Non-tuple shape: assign common users sharding.
      TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
          << "Expected non-tuple user sharding";
      TF_ASSIGN_OR_RETURN(
          sub_assigned,
          AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
                             user_sharding_tree));
    }

    if (sub_assigned == AssignmentKind::kConflict) {
      // In case of conflict we don't assign any sharding.
      return false;
    } else if (sub_assigned == AssignmentKind::kAssigned) {
      assigned = sub_assigned;
    }
  }

  if (assigned == AssignmentKind::kAssigned) {
    if (ShapeUtil::IsTuple(instruction->shape())) {
      instruction->set_sharding(HloSharding::Tuple(sharding_tree));
    } else {
      TF_RET_CHECK(sharding_tree.leaf_count() == 1);
      instruction->set_sharding(sharding_tree.leaf_begin()->second);
    }
    return true;
  }
  return false;
}

// Propagates sharding from users to operands across the instructions of
// `domain`. Instructions that already have a sharding are skipped. The others
// are handed to ApplyShardingFromUsers().
// Returns the number of instructions that received a sharding.
StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
                                        const HloSharding& domain_sharding) {
  int64 newly_assigned = 0;
  // domain.instructions is in post order. Walking it backwards (reverse post
  // order) processes every user before its operands, which is what
  // user->operand propagation needs.
  const auto& instructions = domain.instructions;
  for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
    HloInstruction* instruction = *it;
    if (!instruction->has_sharding()) {
      TF_ASSIGN_OR_RETURN(
          const bool took_sharding,
          ApplyShardingFromUsers(instruction, domain, domain_sharding));
      if (took_sharding) {
        ++newly_assigned;
        VLOG(4) << "  " << instruction->name() << " to sharding "
                << instruction->sharding();
      }
    }
  }
  return newly_assigned;
}

// Applies `sharding` to the instructions of `domain`.
//
// If `sharding` reduces to a single sharding (ExtractSingleSharding()
// succeeds), it is applied to every unsharded instruction through
// ApplyDomainSingleSharding().
//
// Otherwise shardings are propagated from users to operands with
// ApplyDomainShardingPass(), followed by a cleanup:
// - Instructions still without a sharding are reported with LOG(WARNING).
// - Instructions whose sharding still uses kUnassignedDevice (which must be
//   tuples, TF_RET_CHECK-ed) have their sharding cleared.
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
                           const HloSharding& sharding) {
  // None of the external normalizers handled the domain sharding. First check
  // whether this is a single sharding.
  auto single_sharding = sharding.ExtractSingleSharding();
  if (single_sharding) {
    // Simple case: the sharding is unique, so ApplyDomainSingleSharding()
    // applies the array- or tuple-shaped sharding to the domain instructions.
    return ApplyDomainSingleSharding(domain, *single_sharding);
  }
  VLOG(1) << "Assigning non-trivial sharding " << sharding;
  TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());

  for (HloInstruction* instruction : domain.instructions) {
    if (!instruction->has_sharding()) {
      // Unassigned instructions are only reported. Whether this should be an
      // error is still an open question.
      LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
      continue;
    }
    // Clear the sharding of tuples whose sub-shardings are still assigned to
    // kUnassignedDevice. When in doubt, it is better to leave the entire tuple
    // unassigned and let the device placer decide for it.
    if (instruction->sharding().UsesDevice(kUnassignedDevice)) {
      TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()))
          << "Only tuples can have kUnassignedDevice sub shardings";
      instruction->clear_sharding();
    }
  }
  return Status::OK();
}

// Recovers the common sharding of `instructions`. All of them belong to the
// same computation.
//
// Callers reach this only when the instructions had matching shardings (or no
// sharding) according to ShardingMatches(), so no kDomain was inserted between
// them. A mismatch is reported through TF_RET_CHECK.
//
// Returns the first sharding found, normalized by CloneShardingForDomain(), or
// a null pointer if none of the instructions has a sharding.
StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
    absl::Span<HloInstruction* const> instructions) {
  std::shared_ptr<const HloSharding> common;
  for (HloInstruction* instruction : instructions) {
    if (!instruction->has_sharding()) {
      continue;
    }
    if (common == nullptr) {
      common = instruction->sharding_ptr();
      continue;
    }
    TF_RET_CHECK(ShardingMatches(*common, instruction->sharding()))
        << "Sharding " << *common << " does not match the one in "
        << instruction->ToString();
  }
  if (common == nullptr) {
    return std::shared_ptr<const HloSharding>();
  }
  VLOG(4) << "Extracted sharding is " << *common;
  return CloneShardingForDomain(common);
}

}  // namespace

std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
  std::unique_ptr<HloSharding> sharding;
  if (sharding_ != nullptr) {
    sharding = absl::make_unique<HloSharding>(*sharding_);
  }
  return absl::make_unique<ShardingMetadata>(std::move(sharding));
}

bool ShardingMetadata::Matches(const DomainMetadata& other) const {
  const ShardingMetadata* other_ptr =
      dynamic_cast<const ShardingMetadata*>(&other);
  if (other_ptr == nullptr) {
    // If other is not a ShardingMetadata, then it is clearly a no match.
    return false;
  }
  if (sharding_ == nullptr) {
    return other_ptr->sharding_ == nullptr;
  }
  return other_ptr->sharding_ != nullptr
             ? ShardingMatches(*sharding_, *other_ptr->sharding_)
             : false;
}

// Returns a hash of this metadata that is consistent with Matches().
//
// Matches() compares shardings with ShardingMatches(). That treats two
// shardings as equal when both reduce to the same single sharding via
// ExtractSingleSharding(), e.g. a tuple sharding whose elements all share one
// sharding. Hashing the full sharding would give such equal metadata different
// hashes. To avoid that, the hash is computed on the single sharding whenever
// one can be extracted, and on the full sharding otherwise.
//
// Metadata without a sharding hashes to a fixed constant.
size_t ShardingMetadata::Hash() const {
  if (sharding_ == nullptr) {
    return static_cast<size_t>(0x297814aaad196e6dULL);
  }
  auto single_sharding = sharding_->ExtractSingleSharding();
  if (single_sharding) {
    return single_sharding->Hash();
  }
  return sharding_->Hash();
}

// Returns the textual form of the sharding, or "{}" when there is none.
string ShardingMetadata::ToString() const {
  if (sharding_ == nullptr) {
    return "{}";
  }
  return sharding_->ToString();
}

// Downcasts `metadata`, which must be non-null (it is dereferenced), to a
// ShardingMetadata. The cast is made only when its Kind() equals
// ShardingMetadata::KindName(). Otherwise an INVALID_ARGUMENT status is
// returned.
/*static*/ StatusOr<const ShardingMetadata*>
ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) {
  if (metadata->Kind() == ShardingMetadata::KindName()) {
    return static_cast<const ShardingMetadata*>(metadata);
  }
  return Status(
      tensorflow::error::INVALID_ARGUMENT,
      "ShardingMetadata normalizer called with incorrect domain metadata");
}

// Normalizes the sharding of the instructions in `domain`.
//
// With no `metadata`:
// - Recovers the common sharding of the domain instructions and applies it.
// - If there is none, it only logs.
//
// With `metadata` (which must be ShardingMetadata):
// - If it carries a sharding, applies that sharding to the domain, then
//   repairs any pass-through (empty-domain) links.
// - If it carries no sharding, nothing happens.
Status ShardingMetadata::NormalizeShardingDomain(
    const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
  if (metadata == nullptr) {
    TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> common_sharding,
                        ExtractOriginalCommonSharding(domain.instructions));
    if (common_sharding == nullptr) {
      VLOG(1) << "Unable to find common sharding";
      return Status::OK();
    }
    VLOG(4) << "Normalizing sharding-less domain to "
            << common_sharding->ToString();
    return ApplyDomainSharding(domain, *common_sharding);
  }

  TF_ASSIGN_OR_RETURN(const ShardingMetadata* sharding_metadata,
                      ToShardingMetadata(metadata));
  const HloSharding* sharding = sharding_metadata->sharding();
  if (sharding == nullptr) {
    return Status::OK();
  }
  VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":";
  TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
  return FixupPassThroughDomainLinks(domain, *sharding);
}

// Returns a kDomain instruction to be placed between `instruction` (the user)
// and `operand`, or nullptr when no domain is needed.
//
// No domain is needed when neither `instruction` nor `root` has a sharding, or
// when both have shardings that match per ShardingMatches(). Otherwise:
// - Both shardings are normalized with CloneShardingForDomain().
// - Domains are reused through domain_cse_map_, keyed on the operand and the
//   user-side sharding.
// - A new domain uses `root`'s sharding as operand-side metadata and
//   `instruction`'s sharding as user-side metadata.
HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
                                                  HloInstruction* root,
                                                  HloInstruction* operand) {
  auto instruction_sharding = instruction->sharding_ptr();
  auto root_sharding = root->sharding_ptr();
  if (instruction_sharding == nullptr) {
    if (root_sharding == nullptr) {
      // Neither side is sharded: nothing to separate.
      return nullptr;
    }
  } else if (root_sharding != nullptr &&
             ShardingMatches(*instruction_sharding, *root_sharding)) {
    // Both sides are sharded and match: nothing to separate.
    return nullptr;
  }

  if (instruction_sharding != nullptr) {
    instruction_sharding = CloneShardingForDomain(instruction_sharding);
  }
  if (root_sharding != nullptr) {
    root_sharding = CloneShardingForDomain(root_sharding);
  }

  const DomainCseMapKey cse_key{operand, instruction_sharding};
  auto cached = domain_cse_map_.find(cse_key);
  if (cached != domain_cse_map_.end()) {
    return cached->second;
  }

  const auto describe =
      [](const std::shared_ptr<const HloSharding>& sharding) -> string {
    return sharding != nullptr ? sharding->ToString() : "None";
  };
  VLOG(3) << "Creating domain:";
  VLOG(3) << "  Instruction: " << instruction->name();
  VLOG(3) << "  Operand: " << operand->name();
  VLOG(3) << "    User side sharding: " << describe(instruction_sharding);
  VLOG(3) << "    Operand side sharding: " << describe(root_sharding);

  HloInstruction* domain = operand->parent()->AddInstruction(
      HloInstruction::CreateDomain(
          operand->shape(), operand,
          absl::make_unique<ShardingMetadata>(root_sharding),
          absl::make_unique<ShardingMetadata>(instruction_sharding)));
  domain_cse_map_.emplace(cse_key, domain);
  return domain;
}

// Two keys are equal when they refer to the same instruction and either both
// have no sharding, or both have shardings that compare equal with
// HloSharding::operator==.
bool ShardingDomainCreator::DomainCseMapKey::operator==(
    const ShardingDomainCreator::DomainCseMapKey& other) const {
  if (instruction != other.instruction) {
    return false;
  }
  if (sharding == nullptr || other.sharding == nullptr) {
    // At least one side is null: equal only if both are.
    return sharding == other.sharding;
  }
  return *sharding == *other.sharding;
}

// Combines the instruction pointer hash with the sharding hash. A missing
// sharding hashes to a fixed constant.
size_t ShardingDomainCreator::DomainCseMapHasher::operator()(
    const ShardingDomainCreator::DomainCseMapKey& key) const {
  const size_t instruction_hash =
      std::hash<const HloInstruction*>{}(key.instruction);
  const size_t sharding_hash =
      key.sharding != nullptr ? key.sharding->Hash()
                              : static_cast<size_t>(0x297814aaad196e6dULL);
  return tensorflow::Hash64Combine(instruction_hash, sharding_hash);
}

}  // namespace xla