aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_analyzer/sig_node.h
blob: 45c0ed31626ec99d1c443313f9b4d6ef9a6fa43a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace grappler {
namespace graph_analyzer {

namespace test {
// Forward declaration of the test fixture. It is declared a friend by the
// classes in this header so that the tests can reach their private members.
class SigBaseTest;
}  // end namespace test

// Forward declaration: SigNodeMap below refers to SigNode before the class
// itself is defined.
class SigNode;

// To find nodes by name. Having the map ordered makes the tests easier,
// and it isn't used in production code often enough to get any win from
// using an unordered map.
// The map holds the nodes through std::unique_ptr, i.e. it owns them.
using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>;

// One node in the graph, in the form convenient for generation of the signature
// of the graph, and comparison of two (sub)graphs for equivalence. It refers to
// the original NodeDef protobuf for most information and adds the extra
// enrichment.
//
// The graph building is 2-stage: first match a SigNode with each NodeDef and
// collect them into a map that finds them by name, then process the map,
// deep-parse the underlying NodeDefs and connect the SigNodes together.
class SigNode {
 public:
  // Signature drives the hash computation through the private methods below.
  friend struct Signature;

  // Will keep the pointer to the underlying NodeDef, so that
  // underlying object must not be deleted while SigNode is alive.
  explicit SigNode(const NodeDef* node);

  // Access wrappers. All of them forward to the underlying NodeDef.
  const string& name() const { return node_->name(); }
  const string& opcode() const { return node_->op(); }
  const NodeDef* node_def() const { return node_; }

  // For extraction of subgraphs into a separate SigNodeMap, copies the links
  // that point inside the subgraph from a full-graph SigNode to a subgraph
  // SigNode. The translation map defines the subgraph and gives the mapping
  // from the nodes in the full graph to the matching nodes in subgraph.
  using TranslationMap =
      std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>;
  void CopyLinks(const GenNode& from, const TranslationMap& map);

  // A link is an edge of the graph that connects 2 nodes. Each of the connected
  // nodes has its own perspective on the link, seeing its local port, remote
  // port and the remote node. The direction of the link is encoded in the
  // ports, one port is always incoming and another one outgoing.
  //
  // The link tag here contains both ports of the link viewed from the
  // perspective of this node; consisting of both the local port (i.e. at this
  // node) and remote port (i.e. on the other node), the local one going first.
  struct LinkTag {
    // Hash functor for LinkTag: hashes the local port and then combines the
    // hash of the remote port into it, so the result depends on which port is
    // local and which is remote.
    struct Hasher {
      size_t operator()(const LinkTag& tag) const noexcept {
        size_t hval = port_hasher(tag.local);
        CombineHash(port_hasher(tag.remote), &hval);
        return hval;
      }
      GenNode::Port::Hasher port_hasher;
    };

    LinkTag(GenNode::Port a_local, GenNode::Port a_remote)
        : local(a_local), remote(a_remote) {}

    // The default constructor is used for the default values in maps.
    // (false, 99) is an arbitrary value that makes the uninitialized
    // links easy to tell when debugging (they should never happen).
    LinkTag() : local(false, 99), remote(false, 99) {}

    // Port of the link on the local node.
    GenNode::Port local;
    // Port of the link on the remote node.
    GenNode::Port remote;

    // Two tags are equal when both their ports are equal.
    bool operator==(const LinkTag& other) const {
      return local == other.local && remote == other.remote;
    }
    // Lexicographic ordering: by the local port first, then by the remote one.
    bool operator<(const LinkTag& other) const {
      return local < other.local ||
             (local == other.local && remote < other.remote);
    }
  };

  // Since the signature logic doesn't differentiate between the links
  // with the same tag (other than by the "peer" nodes on their other ends),
  // all the links with the same tag are grouped into a single structure.
  struct Link {
    LinkTag tag;
    size_t unique_hash;  // Hash of the tag after conflict resolution.
    // The remote node(s) on the other side on the link(s).
    using PeerVector = std::vector<SigNode*>;
    PeerVector peers;
  };

  // A way to look up the link description by its hash.
  using LinkHashMap = std::map<size_t, Link>;
  const LinkHashMap& hash_to_link() const { return hash_to_link_; }

  // The enumeration of all the peer nodes in a predictable order.
  // Before the signature generation, only the link values determine the
  // order, after the signature generation the entries at the same
  // links get further sorted by their peer node ranks.
  struct HashedPeer {
    HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {}

    // Strict ordering of the entries by the unique rank of their peer nodes.
    // The call operator is const, so the functor can also be used where the
    // comparator object itself is const.
    struct LessByRank {
      bool operator()(const SigNode::HashedPeer& left,
                      const SigNode::HashedPeer& right) const {
        return left.peer->unique_rank_ < right.peer->unique_rank_;
      }
    };

    size_t link_hash;
    SigNode* peer;
  };
  using HashedPeerVector = std::vector<HashedPeer>;
  const HashedPeerVector& hashed_peers() const { return hashed_peers_; }

  // Compares two nodes in two different graphs for equivalence (two nodes in
  // the same graph would never be equivalent). Expects that the signatures of
  // the graphs have already been computed, so unique_rank_ is filled in and
  // the hashed_peers_ properly ordered.
  bool operator==(const SigNode& other) const;

  bool operator!=(const SigNode& other) const { return !(*this == other); }

 private:
  friend class test::SigBaseTest;

  // The CopyLinks code is split into 2 parts for testability.
  // The first pass builds a map ordered by LinkTag for predictability.
  void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
                      std::map<LinkTag, Link>* link_map);
  // The second pass converts to the map by hash value,
  // resolves any hash conflicts, and builds the hashed peer vector.
  void CopyLinksPass2(std::map<LinkTag, Link>* link_map);

  // Computes the topological hash at distance 0. Resets the topo_hash_ vector
  // and the hashed nodes.
  // NOTE(review): "hashed nodes" presumably means last_hashed_nodes_ and/or
  // next_hashed_nodes_ -- confirm against the implementation.
  void ComputeTopoHash0();

  // Computes the topological hash at the given distance. The hashes for all
  // the lower distances must be already computed for all the nodes in the
  // graph. Also computes next_hashed_nodes_ from last_hashed_nodes_.
  void ComputeTopoHash(int distance);

  // Gets the hash value for a particular distance. It must be previously
  // computed.
  size_t GetTopoHash(int distance) const;

  // Gets the hash value for the highest computed distance. It must be
  // previously computed: an empty topo_hash_ fails the CHECK.
  size_t GetHighTopoHash() const {
    CHECK(!topo_hash_.empty());
    return topo_hash_.back();
  }

  // Rehashes the topmost hash in place (by combining the constant 1 into it),
  // to avoid conflicts. An empty topo_hash_ fails the CHECK.
  void ReHighTopoHash() {
    CHECK(!topo_hash_.empty());
    CombineHash(1, &topo_hash_.back());
  }

  // Ordering of the nodes by their highest available hash. The hash must be
  // previously computed for both nodes, i.e. their topo_hash_ must not be
  // empty. The call operator is const, so the functor can also be used where
  // the comparator object itself is const.
  struct NodeOrderLess {
    bool operator()(const SigNode* left, const SigNode* right) const {
      return left->topo_hash_.back() < right->topo_hash_.back();
    }
  };

  const NodeDef* node_;

  // The bitmap mask with 1 bit set that represents this node in the set
  // during the computation of the signature.
  uint64_t node_mask_ = 0;

  // The code that populates this map makes sure that there are no hash
  // conflicts, rehashing if necessary.
  LinkHashMap hash_to_link_;

  // The enumeration of all the direct peers in the predictable order (which
  // happens to be the order of their link tags, but the order of the hashes
  // would do too). It is used for the quick enumeration during the signature
  // computation. After the signature building is completed, the entries that
  // have the same link tag get further sorted in the order of the ranks of
  // their nodes.
  HashedPeerVector hashed_peers_;

  // The unique rank represents the order in which the node will be included
  // into the signature. It gets assigned in order either when the topo_hash_ of
  // this node becomes unique in the graph, or when the nodes are completely
  // equivalent, one of them is picked at random to assign the next rank, and
  // then the rest of the nodes attempt to disambiguate based on that
  // information.
  // Starts out with all the bits set (the maximal size_t value); spelled with
  // an explicit size_t operand to avoid relying on an int-to-unsigned
  // conversion.
  size_t unique_rank_ = ~size_t{0};
  // When hash_is_final_ is set, the topo_hash_ vector stops growing, and the
  // last value from it is used for all the further hashes.
  bool hash_is_final_ = false;
  // The hashes that include the topology of the nodes up to the distance N. The
  // hash for distance 0 is produced from the attributes of this node itself and
  // its general connectivity properties but no information about the
  // neighboring nodes. The hash for distance D+1 is built from hashes at level
  // D of this node and of all its immediate neighbors. The neighbors that are
  // connected by equivalent links are included in a commutative way.
  std::vector<size_t> topo_hash_;
  // The set of nodes that got included into the computation of the
  // last topo_hash_ entry.
  uint64_t last_hashed_nodes_ = 0;
  // The next set of nodes that gets used for the current topo_hash_ entry.
  uint64_t next_hashed_nodes_ = 0;
};

// Signature of a graph. The computation is intertwined with the private methods
// of SigNode, so keeping both in the same file looks more convenient.
//
// Usage, as far as this header shows: fill in `map` with the nodes of the
// graph, call Compute(), and then use sig_short / sig_full / nodes,
// operator== and ToString().
struct Signature {
  // Lets the test fixture reach the private helper methods below.
  friend class test::SigBaseTest;

  // Maximal size of the graphs for which the signature can be computed.
  // Changing this constant won't magically add the support for a larger size,
  // the rest of implementation would have to be extended. The value of 64 is
  // driven by the size of a bitset in an uint64_t, and should be enough for our
  // purposes, while having a high efficiency of implementation.
  static constexpr int kMaxGraphSize = 64;

  // Using the map, computes the rest of the fields of a signature.
  // Returns an error if the graph is too big (see kMaxGraphSize).
  Status Compute();

  // Convert the computed signature to a string representation.
  string ToString() const;

  SigNodeMap map;        // The nodes in the graph, accessible by name.
  size_t sig_short = 0;  // Hash of the signature, for the quick equality check.
  // The full signature: hashes of the nodes in a predictable order.
  std::vector<size_t> sig_full;
  // The nodes in the same order as they go in the signature.
  // These are non-owning pointers; the nodes themselves are held by `map`.
  std::vector<SigNode*> nodes;

  // For building the unordered maps. Returns sig_short, so it is only
  // meaningful after the signature has been computed (before that sig_short
  // is 0).
  size_t Hash() const { return sig_short; }

  // Returns true if the graphs are equivalent. The signature must be already
  // computed.
  bool operator==(const Signature& other) const;

 private:
  // Populates the nodes vector from the map and initializes the state of the
  // nodes for the signature computation.
  void PrepareNodes();

  // Finds the nodes with the hashes that are unique and assigns the unique ids
  // to them. If there are nodes with non-unique hashes, exactly one node from
  // the first such sequence (in the order of hash values) will be picked and
  // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have
  // been already assigned the unique ids. Advances next_node_id by at least 1.
  // next_node_id_p is an in-out parameter pointing to next_node_id.
  void FindUniqueHashes(size_t* next_node_id_p);

  // One round of the signature computation. Assumes that the
  // nodes[0...(next_node_id-1)] have been already assigned the fixed
  // positions, and thus computes the hashes only for the remaining nodes.
  void ComputeOneRound(size_t next_node_id);

  // Additional ordering of the hashed_peers_ links in the nodes, so that they
  // can be compared and printed in a predictable order.
  void OrderLinks();
};

}  // end namespace graph_analyzer
}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_