aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_analyzer/subgraph.h
blob: 4de31d5dfa2a03dbf0adeb3f0732d59c6d86da00 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_

#include <initializer_list>
#include <set>

#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace tensorflow {
namespace grappler {
namespace graph_analyzer {

// The description of a single subgraph for processing.
class Subgraph {
 public:
  // Identity of a single subgraph as a set of nodes.
  class Identity : public gtl::FlatSet<const GenNode*> {
   public:
    using InitializerList = std::initializer_list<GenNode*>;

    Identity() = default;
    Identity(InitializerList init);
    bool operator<(const Identity& other) const;
    bool operator==(const Identity& other) const;

    // Compute the hash.
    size_t Hash() const;
  };

  explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {}

  // Construct by extending the parent identity with an extra node.
  Subgraph(const Identity& parent_id, GenNode* add_node);

  Subgraph() = delete;
  Subgraph(const Subgraph& other) = delete;
  void operator=(const Subgraph& other) = delete;

  // Order for building sets of subgraphs.
  bool operator<(const Subgraph& other) const { return this->id_ < other.id_; }
  // Support for hashed sets.
  bool operator==(const Subgraph& other) const {
    return this->id_ == other.id_;
  }
  size_t Hash() const { return hash_; }

  // Dump the subgraph information to a string.
  string Dump();

  // Extract this subgraph into a separate graph representation for signature
  // building, that includes only the links between the nodes in the subgraph
  // and drops all the external links. The result map should be clear before the
  // call.
  void ExtractForSignature(SigNodeMap* result);

  const Identity& id() const { return id_; }
  bool specific() const { return specific_; }
  void SetSpecific(bool value) { specific_ = value; }
  int32_t collation_count() const { return collation_count_; }
  void AddCollation(int32_t n = 1) { collation_count_ += n; }
  void ResetCollation() { collation_count_ = 1; }
  void MergeCollation(const Subgraph& other) {
    collation_count_ += other.collation_count_;
  }

 private:
  // Identity also serves as the list of nodes. It never changes throughout the
  // life of subgraph.
  Identity id_;
  size_t hash_;  // Cached from the identity.
  // Whether the dump should include the specific names of the nodes. The
  // non-specific (i.e. generic) subgraphs represent a collation of multiple
  // subgraphs.
  bool specific_ = true;
  // How many collated subgraphs are represented by this subgraph.
  int32_t collation_count_ = 1;
};

// Iteration of all links in a subgraph. This is more like Java iterators than
// the normal C++ iterators. It's simpler this way and there seems to be no
// major reason to make it a proper C++ iterator.
class SubgraphIterator {
 public:
  // Obviously an iterator is valid only until the original object
  // gets destroyed.
  explicit SubgraphIterator(const Subgraph::Identity* id);
  explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {}

  // Check whether the built-in iterator is at the end.
  bool AtEnd() const { return id_it_ == id_->end(); }

  // Get the neighbor at the current iterator.
  // MUST NOT be called when AtEnd();
  const GenNode::LinkTarget& GetNeighbor() const {
    return link_map_it_->second[link_idx_];
  }

  // Get the node at the current iterator.
  // MUST NOT be called when AtEnd();
  const GenNode* GetNode() const { return *id_it_; }

  // Get the port leading to the neighbor at the current iterator.
  // MUST NOT be called when AtEnd();
  GenNode::Port GetPort() const { return link_map_it_->first; }

  // Increases the iterator.
  // Returns true if NOT AtEnd() after increasing the iterator.
  // Safe to call if already AtEnd().
  bool Next();

  // If there are more links at the same port, increases the iterator and
  // returns true. Otherwise leaves the iterator unchanged and returns false.
  bool NextIfSamePort();

  // Increases the iterator directly to the last position on the current port
  // (or if already there then doesn't increase). Equivalent to calling
  // NextIfSamePort() while it returns true, but faster.
  // Safe to call if already AtEnd().
  void SkipPort();

  // Increases the iterator directly to the last position on the current node.
  // Safe to call if already AtEnd().
  void SkipNode();

  // Returns true if the iterators are exactly the same.
  bool operator==(const SubgraphIterator& other) const;
  bool operator!=(const SubgraphIterator& other) const {
    return !(*this == other);
  }

 private:
  // After link_idx_ has been increased, make sure that it points to the
  // next valid element (or end) by increasing the higher levels of iteration if
  // needed.
  // Returns true if NOT AtEnd() after increasing the iterator.
  // NOT safe to call if already AtEnd().
  bool PropagateNext();

  // Identity of the subgraph being iterated over.
  const Subgraph::Identity* id_;

  // The current position, allowing to iterate through the links (see the
  // reasoning for it in the public section).
  //
  // (1) Iterator of the nodes in the subgraph.
  Subgraph::Identity::const_iterator id_it_;
  // (2) Iterator in the link map of the node.
  GenNode::LinkMap::const_iterator link_map_it_;
  // (3) Index in the vector of the links.
  int32_t link_idx_;
};

// A convenient way to store subgraphs: in a set of unique_ptrs. This way the
// addresses of subgraph objects will stay stable, and the objects themselves
// won't be copied.
class SubgraphPtrSet
    : public std::unordered_set<std::unique_ptr<Subgraph>,
                                HashAtPtr<std::unique_ptr<Subgraph>>,
                                EqAtPtr<std::unique_ptr<Subgraph>>> {
 public:
  // Attempts to extend the set by adding a new subgraph that gets created by
  // adding one node to the parent subgraph. If such a subgraph already exists,
  // returns nullptr, otherwise returns the pointer to the new subgraph.
  Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node);
};

}  // end namespace graph_analyzer
}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_