aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_domain_map.h
blob: e62ef763fb3881ab6030b1f6a66266ac80a3d84d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

// The HloDomainMap splits a set of instructions within a module or computation
// into different domains, separated by kDomain instructions.
// A domain is composed of a set of instructions which can reach each other via
// operand/user edges, without crossing a kDomain instruction of a given kind.
// A domain never crosses computation boundaries.
class HloDomainMap {
 public:
  // Creates a new HloDomainMap, creating all the domains within the input
  // computation, of the given kind. If domain_kind is not empty, only the
  // kDomain instructions of domain_kind will be considered as separators.
  // Otherwise every kDomain instruction splits domains.
  static StatusOr<std::unique_ptr<HloDomainMap>> Create(
      HloComputation* computation, string domain_kind);

  // Creates a new HloDomainMap, creating all the domains within the input
  // module, of the given kind. If domain_kind is not empty, only the
  // kDomain instructions of domain_kind will be considered as separators.
  // Otherwise every kDomain instruction splits domains.
  static StatusOr<std::unique_ptr<HloDomainMap>> Create(HloModule* module,
                                                        string domain_kind);

  // Retrieves all the domains the input module or computation is composed of.
  // The returned reference points into storage owned by this map, so it is
  // only valid while the map itself is alive.
  const std::vector<std::unique_ptr<DomainMetadata::Domain>>& GetDomains()
      const {
    return instruction_domains_;
  }

  // Checks whether two instructions are within the same domain.
  bool InSameDomain(HloInstruction* instruction1,
                    HloInstruction* instruction2) const;

  // Checks whether instruction is a kDomain instruction of the kind we are
  // currently processing.
  bool IsDomainInstruction(HloInstruction* instruction) const;

 private:
  // Marked explicit so a bare domain-kind string can never be implicitly
  // converted into an HloDomainMap. The constructor is private; external code
  // obtains instances through the static Create() factories.
  explicit HloDomainMap(string domain_kind)
      : domain_kind_(std::move(domain_kind)) {}

  // Checks whether the kDomain instruction is facing (via its operand link)
  // another kDomain instruction of the same kind, hence defining an empty
  // domain. If that is the case, creates the empty domain and calls the proper
  // normalizer.
  Status TryProcessEmptyDomain(HloInstruction* instruction);

  // Builds the domains for the instructions of a single computation.
  // NOTE(review): exact traversal order is defined in the .cc; not visible here.
  Status Populate(HloComputation* computation);

  // Inserts the provided domain into the ones tracked by this object,
  // creating a new domain ID.
  Status InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain);

  // From the given instruction, expands operand-wise and user-wise the set of
  // instructions which can be reached without crossing a kDomain instruction
  // of the kind specified by domain_kind_.
  // The domain data structure will be populated with all the reached
  // instructions, and the boundaries of the domain, with the kDomain
  // instructions encountered while expanding the reach.
  Status ExpandDomain(HloInstruction* instruction,
                      DomainMetadata::Domain* domain) const;

  // Creates a domain data structure using the ExpandDomain() API.
  StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain(
      HloInstruction* instruction) const;

  // Out of an instruction set, returns a vector of all the ones which are not
  // of kDomain kind.
  static std::vector<HloInstruction*> MakeNonDomainInstructions(
      const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set);

  // The kDomain kind acting as separator. Empty means every kDomain
  // instruction separates domains (see Create()).
  string domain_kind_;
  // Owning storage for every domain tracked by this map.
  std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
  // Maps an instruction to the ID of the domain containing it. Presumably the
  // ID is an index into instruction_domains_ — TODO confirm against the .cc.
  tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_