aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/edgeset.h
blob: df0d78b8fbbf429ee7c30ad95d948d8958262c62 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#ifndef TENSORFLOW_GRAPH_EDGESET_H_
#define TENSORFLOW_GRAPH_EDGESET_H_

#include <stddef.h>
#include <set>
#include "tensorflow/core/platform/port.h"

#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

class Edge;

// An unordered set of edges.  Uses very little memory for small sets.
// Unlike std::set, EdgeSet does NOT allow mutations during iteration.
class EdgeSet {
 public:
  EdgeSet();
  ~EdgeSet();

  typedef const Edge* key_type;
  typedef const Edge* value_type;
  typedef size_t size_type;
  typedef ptrdiff_t difference_type;

  class const_iterator;
  typedef const_iterator iterator;

  bool empty() const;
  size_type size() const;
  void clear();
  std::pair<iterator, bool> insert(value_type value);
  size_type erase(key_type key);

  // Caller is not allowed to mutate the EdgeSet while iterating.
  const_iterator begin() const;
  const_iterator end() const;

 private:
  // Up to kInline elements are stored directly in ptrs_ (nullptr means none).
  // If ptrs_[0] == this then ptrs_[1] points to a set<const Edge*>.
  static const int kInline = 2;  // Must be >= 2.
  const void* ptrs_[kInline];

  std::set<const Edge*>* get_set() const {
    if (ptrs_[0] == this) {
      return static_cast<std::set<const Edge*>*>(const_cast<void*>(ptrs_[1]));
    } else {
      return nullptr;
    }
  }

// To detect mutations while iterating.
#ifdef NDEBUG
  void RegisterMutation() {}
#else
  uint32 mutations_ = 0;
  void RegisterMutation() { mutations_++; }
#endif

  TF_DISALLOW_COPY_AND_ASSIGN(EdgeSet);
};

// Forward iterator over the edges of an EdgeSet. Dereferencing yields the
// stored `const Edge*` by value.
//
// In debug builds (NDEBUG undefined) every operation CHECK-fails if the owning
// EdgeSet has registered a mutation since the iterator was initialized.
class EdgeSet::const_iterator {
 public:
  using value_type = EdgeSet::value_type;
  using reference = const EdgeSet::value_type&;
  using pointer = const EdgeSet::value_type*;
  using difference_type = EdgeSet::difference_type;
  using iterator_category = std::forward_iterator_tag;

  // A default-constructed iterator is not attached to any EdgeSet.
  const_iterator() {}

  const_iterator& operator++();
  const_iterator operator++(int /*unused*/);
  const value_type* operator->() const;
  value_type operator*() const;
  bool operator==(const const_iterator& other) const;
  bool operator!=(const const_iterator& other) const {
    return !(*this == other);
  }

 private:
  friend class EdgeSet;

  // Exactly one representation is active, mirroring the owner's storage:
  // a non-null array_iter_ walks the inline ptrs_ slots; otherwise tree_iter_
  // walks the overflow std::set.
  void const* const* array_iter_ = nullptr;
  std::set<const Edge*>::const_iterator tree_iter_;

#ifdef NDEBUG
  inline void Init(const EdgeSet* e) {}
  inline void CheckNoMutations() const {}
#else
  // Remembers the owning set and its mutation count at this moment.
  inline void Init(const EdgeSet* e) {
    owner_ = e;
    init_mutations_ = e->mutations_;
  }
  // Fails if the owner has been mutated since Init(). A default-constructed
  // iterator never had Init() called and has no owner to compare against;
  // without this guard, comparing two such iterators would dereference null.
  inline void CheckNoMutations() const {
    if (owner_ == nullptr) return;
    CHECK_EQ(init_mutations_, owner_->mutations_);
  }
  const EdgeSet* owner_ = nullptr;
  uint32 init_mutations_ = 0;
#endif
};

// Constructs an empty set: every inline slot is null and no overflow std::set
// exists yet.
inline EdgeSet::EdgeSet() {
  for (const void*& slot : ptrs_) slot = nullptr;
}

// Frees the overflow std::set if one is in use; get_set() yields nullptr in
// inline mode, in which case the delete is a no-op.
inline EdgeSet::~EdgeSet() {
  std::set<const Edge*>* overflow = get_set();
  delete overflow;
}

inline bool EdgeSet::empty() const { return size() == 0; }

// Number of stored edges: the overflow set's size when one is in use,
// otherwise the count of non-null inline slots.
inline EdgeSet::size_type EdgeSet::size() const {
  if (const std::set<const Edge*>* overflow = get_set()) {
    return overflow->size();
  }
  size_type occupied = 0;
  for (const void* slot : ptrs_) {
    if (slot != nullptr) ++occupied;
  }
  return occupied;
}

inline void EdgeSet::clear() {
  RegisterMutation();
  delete get_set();
  for (int i = 0; i < kInline; i++) {
    ptrs_[i] = nullptr;
  }
}

// Iterator to the first edge. It walks the overflow set when one is in use,
// otherwise it starts at the first inline slot.
inline EdgeSet::const_iterator EdgeSet::begin() const {
  const_iterator first;
  first.Init(this);
  if (auto* overflow = get_set()) {
    first.tree_iter_ = overflow->begin();
    return first;
  }
  first.array_iter_ = ptrs_;
  return first;
}

// Past-the-end iterator matching begin()'s representation.
inline EdgeSet::const_iterator EdgeSet::end() const {
  const_iterator last;
  last.Init(this);
  if (auto* overflow = get_set()) {
    last.tree_iter_ = overflow->end();
    return last;
  }
  // Slot index size() is treated as one past the last edge. This relies on
  // the non-null inline slots being packed at the front of ptrs_;
  // insert()/erase() (defined outside this header) must keep them that way.
  last.array_iter_ = ptrs_ + size();
  return last;
}

// Pre-increment: advances whichever representation (set or inline slots) this
// iterator is using.
inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() {
  CheckNoMutations();
  if (array_iter_ == nullptr) {
    ++tree_iter_;
  } else {
    ++array_iter_;
  }
  return *this;
}

// Post-increment: returns a copy of the iterator as it was, then advances.
inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++(
    int /*unused*/) {
  CheckNoMutations();
  const_iterator before(*this);
  ++*this;
  return before;
}

// Member access. std::set exposes its elements only through const iterators
// (changing a key in place would break the ordering), so both branches yield
// a pointer to a const element.
inline const EdgeSet::const_iterator::value_type*
EdgeSet::const_iterator::operator->() const {
  CheckNoMutations();
  if (array_iter_ == nullptr) return tree_iter_.operator->();
  // Inline slots are stored as const void*; view the slot as a const Edge*.
  return reinterpret_cast<const value_type*>(array_iter_);
}

// Dereference: returns the current edge pointer by value, read either from the
// inline slot (cast back from const void*) or from the overflow set.
inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*()
    const {
  CheckNoMutations();
  return array_iter_ != nullptr ? static_cast<value_type>(*array_iter_)
                                : *tree_iter_;
}

// Equality. Both iterators are expected to use the same representation; a
// mismatch means they came from different sets or the set changed mode, which
// DCHECK reports in debug builds.
inline bool EdgeSet::const_iterator::operator==(
    const const_iterator& other) const {
  const bool self_inline = array_iter_ != nullptr;
  const bool other_inline = other.array_iter_ != nullptr;
  DCHECK(self_inline == other_inline)
      << "Iterators being compared must be from same set that has not "
      << "been modified since the iterator was constructed";
  CheckNoMutations();
  if (self_inline) return array_iter_ == other.array_iter_;
  return !other_inline && tree_iter_ == other.tree_iter_;
}

}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_EDGESET_H_