aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_util.h
blob: f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_

#include <functional>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

// Collection of utilities for handling HloModuleGroups.
class HloModuleGroupUtil {
 public:
  explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata)
      : metadata_(metadata) {}

  // Returns all unique predecessors of the instruction. This includes:
  // * predecessors in the same computation: operands and control predecessors
  // * Recv is a predecessor of Send
  // * Send is a predecessor of RecvDone
  // * predecessors of companions (if the instruction is a companion while)
  // * predecessors' companions (for any predecessor that is a companion while)
  std::vector<HloInstruction*> GlobalPredecessors(HloInstruction* instruction);

  // Returns all unique successors of the instruction. This includes:
  // * successors in the same computation: users and control successors
  // * Send is a successor of Recv
  // * RecvDone is a predecessor of Send
  // * successors of companions (if the instruction is a companion while)
  // * successors' companions (for any successor that is a companion while)
  std::vector<HloInstruction*> GlobalSuccessors(HloInstruction* instruction);

  // Returns the root instructions of the computations.
  std::vector<HloInstruction*> RootInstructions(
      absl::Span<HloComputation* const> computations);

  // Visit state of each instruction during DFS traversal.
  enum VisitState {
    kNotVisited = 0,
    kVisiting,
    kVisited,
  };

  // Function called on each instruction group during the DFS traversal. See the
  // comment for VisitTopologicalOrder()).
  using VisitFunction = std::function<Status(
      HloInstruction* hlo,
      const std::vector<HloInstruction*>& instruction_group)>;

  // Given the hlo instruction as the root, recursively visits all its
  // predecessor instructions in DFS order to visit nodes in topological order.
  //
  // Note that the DFS traversal does not only visit nodes in the same
  // computation (parent of the root instruction), but also visits nodes in
  // different computations connected via communication instructions. During the
  // traversal, companion While instructions (see the class comment in
  // HloModuleGroupMetadata) are treated as a single instruction (called
  // instruction group, which contains only a single instruction if the visiting
  // node is not a companion while) -- visiting one of the instructions in the
  // group effectively visits all other instructions in the group, and then all
  // predecessor instructions of the group are visited.
  //
  // * visit_state: map from each instruction to its visit state.
  // * visit_function: function called when each instruction group.
  // * root: the root instruction of the traversal.
  using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>;
  Status VisitTopologicalOrder(VisitStates* visit_state,
                               const VisitFunction& visit_function,
                               HloInstruction* root);

  // Verifies that the computations are well-formed (e.g., no cycles).
  Status VerifyComputations(absl::Span<HloComputation* const> computations);

  // Below Reachability utils resemble those in HloComputation, except that
  // they can handle instructions across multiple computations.
  //
  // Creates the reachability map for the instructions in the computations.
  StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
      absl::Span<HloComputation* const> computations);

  // Updates the reachability of the given instruction, taking the global
  // predeccessorss and successors into account.
  void UpdateReachabilityThroughInstruction(
      HloInstruction* instruction, HloReachabilityMap* reachability_map);

 private:
  const HloModuleGroupMetadata& metadata_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_