aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.h
blob: 95126d470c6aa3a787614448c722cc8e414f82ed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_H_

#include <cstddef>
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {
namespace grappler {

// A utility class to lookup a node and its outputs by node name.
//
// NOTE(review): only the declarations are visible in this header; the
// per-method notes below that go beyond the signatures are marked as
// assumptions and should be confirmed against the definitions.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);
  // Looks up the node registered under 'name'.
  // NOTE(review): presumably returns nullptr when 'name' is unknown - confirm.
  NodeDef* GetNode(const string& name) const;
  // Reports whether a node is registered under 'name'.
  bool NodeExists(const string& name) const;
  // Returns the set of nodes recorded as outputs of 'node_name'.
  // NOTE(review): presumably returns a reference to empty_set_ when there is
  // no entry for 'node_name' - confirm.
  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& name, NodeDef* node);
  void RemoveNode(const string& name);
  // The methods below edit the recorded input/output relations between nodes,
  // all identified by name.
  // NOTE(review): whether unknown names are tolerated is not visible here.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name);
  void AddOutput(const string& node_name, const string& output_name);
  void RemoveInputs(const string& node_name);
  void RemoveOutput(const string& node_name, const string& output_name);
  void RemoveOutputs(const string& node_name);
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name);

 private:
  // An always-empty set of node pointers (const member, never modified).
  const std::set<NodeDef*> empty_set_;
  // Node name -> node pointer. The pointers are not owned (see the note on the
  // constructor).
  std::unordered_map<string, NodeDef*> nodes_;
  // Node name -> set of node pointers recorded as that node's outputs.
  std::unordered_map<string, std::set<NodeDef*>> outputs_;
};

// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
//
// T must be copyable and usable as a key of std::unordered_set<T, Hash>.
template <class T, class Hash = std::hash<T>>
class SetVector {
 public:
  // Appends 'value' to the back of the vector unless it is already present.
  // Returns false if value already existed in the set, true otherwise.
  bool PushBack(const T& value) {
    if (!set_.insert(value).second) {
      return false;
    }
    vector_.push_back(value);
    return true;
  }

  // Removes the most recently pushed element and returns it.
  // Precondition: !Empty(). Calling this on an empty container is undefined
  // behavior, exactly as std::vector::back() on an empty vector is.
  T PopBack() {
    // Move the element out instead of copying it; the moved-from slot is
    // destroyed by pop_back() right away.
    T back = std::move(vector_.back());
    vector_.pop_back();
    set_.erase(back);
    return back;
  }

  // Returns true iff 'value' is currently stored.
  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }

  // Returns true iff no element is stored.
  bool Empty() const { return vector_.empty(); }

  // Pre-allocates room for 'size' elements. Both containers are sized so that
  // subsequent PushBack() calls trigger neither vector reallocation nor hash
  // table rehashing.
  void Reserve(int64 size) {
    vector_.reserve(size);
    set_.reserve(size);
  }

 private:
  // Invariant: set_ and vector_ always hold exactly the same elements.
  std::unordered_set<T, Hash> set_;
  std::vector<T> vector_;
};

// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);

// True iff 'name1' and 'name2' refer to the same input.
// NOTE(review): the exact equivalence used (e.g. whether "node" and "node:0"
// compare equal) is decided by the out-of-line definition - confirm there.
bool IsSameInput(const string& name1, const string& name2);

// Returns the trailing position number (or zero if no number is present) if
// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
// Returns -2 if NodeName(input_name) is not equal to node_name.
// Note: This function is used very heavily, and this hand-optimized
// version is 3-4x faster than the version using Scanner, which it replaced.
// This is worth the reduction in readability.
inline int NodePositionIfSameNode(const string& input_name,
                                  const string& node_name) {
  if (input_name.empty()) return -2;
  const bool is_ctrl = input_name[0] == '^';
  auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
  auto node_it = node_name.begin();
  if (node_name.empty() ||
      std::distance(input_it, input_name.end()) < node_name.size()) {
    return -2;
  }
  while (node_it != node_name.end()) {
    if (*input_it++ != *node_it++) {
      return -2;
    }
  }
  if (input_it == input_name.end()) {
    return is_ctrl ? -1 : 0;
  } else if (*input_it++ == ':') {
    StringPiece remaining(&(*input_it),
                          std::distance(input_it, input_name.end()));
    int position;
    if (!strings::safe_strto32(remaining, &position)) {
      return -2;
    }
    return is_ctrl ? -1 : position;
  } else {
    return -2;
  }
}

// Returns the node-name portion of the input name 'name': an optional leading
// '^' (control-input marker) is skipped, and everything from the first ':'
// onwards (the output port suffix) is dropped. No further validation is
// performed. Returns an empty piece if 'name' is empty.
//
// Unless 'name' is empty, the returned piece aliases the storage of 'name' and
// is only valid while 'name' is alive and unmodified.
inline StringPiece NodeNameAsStringPiece(const string& name) {
  static const string empty;
  if (name.empty()) return StringPiece(empty);
  const size_t begin = name[0] == '^' ? 1 : 0;
  size_t end = name.find(':', begin);
  if (end == string::npos) end = name.size();
  // Built from data() + offset so that inputs such as "^" (nothing after the
  // marker) never dereference the end iterator.
  return StringPiece(name.data() + begin, end - begin);
}

// Owning-string variant of NodeNameAsStringPiece(): returns a copy of the node
// name that 'name' refers to, which is the empty string when 'name' is empty.
inline string NodeName(const string& name) {
  const StringPiece node_name_piece = NodeNameAsStringPiece(name);
  return string(node_name_piece);
}

// Returns the node name and position in a single call.
//
// 'name' is expected to look like ['^'] node [':' <int32>]. On return:
//   * empty 'name':          returns an empty piece, *position == 0;
//   * control input ("^n"):  returns "n", *position == -1; any ":..." suffix
//                            is dropped without being parsed;
//   * "n":                   returns "n", *position == 0;
//   * "n:<k>":               returns "n", *position == k;
//   * "n:<garbage>":         returns an empty piece; *position is whatever
//                            strings::safe_strto32 left in it.
//
// A non-empty result aliases the storage of 'name' and is only valid while
// 'name' is alive and unmodified. 'position' must not be null.
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
                                              int* position) {
  static const string empty;
  if (name.empty()) {
    *position = 0;
    return StringPiece(empty);
  }
  const bool is_ctrl = name[0] == '^';
  *position = is_ctrl ? -1 : 0;
  const size_t begin = is_ctrl ? 1 : 0;
  const size_t colon = name.find(':', begin);
  if (colon == string::npos) {
    // No port suffix: the node name runs to the end of the string.
    return StringPiece(name.data() + begin, name.size() - begin);
  }
  if (!is_ctrl) {
    // Parse the port number after the ':'. The view is built from
    // data() + offset so that "n:" (empty suffix) never dereferences end().
    const size_t port_begin = colon + 1;
    const StringPiece remaining(name.data() + port_begin,
                                name.size() - port_begin);
    if (!strings::safe_strto32(remaining, position)) {
      return StringPiece(empty);
    }
  }
  return StringPiece(name.data() + begin, colon - begin);
}

// Owning-string variant of ParseNodeNameAsStringPiece(): yields a copy of the
// parsed node name and stores the parsed position in '*position'.
inline string ParseNodeName(const string& name, int* position) {
  const StringPiece parsed_name = ParseNodeNameAsStringPiece(name, position);
  return string(parsed_name);
}

// Returns only the position that ParseNodeNameAsStringPiece() computes for
// 'name'; the parsed node name itself is discarded.
inline int NodePosition(const string& name) {
  int parsed_position;
  // Only the out-parameter is of interest here, not the returned name.
  ParseNodeNameAsStringPiece(name, &parsed_position);
  return parsed_position;
}

// Add a prefix to a node name with a custom delimiter.
// NOTE(review): defined out of line; presumably 'delimiter' is placed between
// 'prefix' and 'name'. How names carrying a leading '^' are treated is not
// visible from this header - confirm against the definition.
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter);

// Add a prefix to a node name.
// NOTE(review): presumably uses a default delimiter chosen by the definition;
// confirm which one.
string AddPrefixToNodeName(const string& name, const string& prefix);

// Executes a 'fn' in the 'thread_pool'. The method waits for the configured
// timeout (in milliseconds) for 'fn' to complete, before returning false.
//
// If returning false, the 'fn' may still continue to execute in the
// thread-pool. It is the responsibility of the caller to reset the thread-pool
// as appropriate.
//
// NOTE(review): presumably returns true when 'fn' completes within the
// timeout; the treatment of non-positive 'timeout_in_ms' values is not visible
// from this header - confirm against the definition.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
                        thread::ThreadPool* thread_pool);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a NodeDef.
string AsControlDependency(const NodeDef& node);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a node name.
// NOTE(review): whether a name that already starts with '^' gets a second
// prefix is decided by the out-of-line definition - confirm.
string AsControlDependency(const string& node);

// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
// NOTE(review): 'graph' is taken by non-const pointer; what it is used for is
// not visible from this header - confirm against the definition.
int NumOutputs(const NodeDef& node, GraphDef* graph);

// Returns true iff the node has at least one control input.
bool HasControlInputs(const NodeDef& node);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Number of connected non-control data outputs (Ops that consume output tensor
// data, not just its shape).
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);

// Removes redundant control inputs from node. 'node' is modified in place.
// NOTE(review): "redundant" presumably means duplicated control inputs;
// confirm the exact rule against the definition.
void DedupControlInputs(NodeDef* node);

// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);

// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
//    source <- a <- b <- ... <- y <- z
// where
//    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
//    pred_fn(z) = false,
// the return value will be a pointer to y.
//
// NOTE(review): 'follow_control_input' presumably controls whether the walk
// continues when input(0) is a control input; 'node_map' is presumably used to
// resolve input names to nodes. Neither is visible from this header - confirm
// against the definition.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn);

// Permute the nodes of graph in place according to the permutation.
// NOTE(review): 'permutation' is taken by non-const pointer, so the call may
// modify it; 'invert_permutation' presumably requests that the inverse of
// 'permutation' be applied - confirm both against the definition.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation);

// Writes the integer 'value' into '*tensor' for the data type 'dtype' and
// reports the outcome as a Status.
// NOTE(review): presumably returns an error for unsupported dtypes; the set of
// supported dtypes is not visible from this header - confirm.
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);

// Removes the nodes listed in 'nodes_to_delete' from 'graph'.
// NOTE(review): the integer elements are presumably indices into the graph's
// node list - confirm against the definition.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);

// Same as above, but takes the list as an rvalue vector, so the caller's
// vector may be consumed or reordered by the call.
// NOTE(review): whether duplicates or unsorted input are accepted is not
// visible from this header - confirm.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);

// Removes nodes from 'graph', identified by strings.
// NOTE(review): the strings are presumably node names - confirm.
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
                         GraphDef* graph);

// An index-based view of a GraphDef. Nodes are addressed by integer index, and
// the view stores, per node index, its name plus the indices of its inputs and
// outputs.
//
// The view does not own the GraphDef. Before Initialize() has been called,
// graph() returns nullptr and node() must not be used.
class SimpleGraphView {
 public:
  // Build a graph view for the specified graphdef.
  Status Initialize(const GraphDef& graph) {
    return Initialize(graph, nullptr, true, true);
  }
  // Build a graph view for the specified graphdef augmented with the additional
  // edges specified in 'extra_dependencies' if any. Note that
  // extra_dependencies can be null.
  Status Initialize(
      const GraphDef& graph,
      const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
          extra_dependencies) {
    return Initialize(graph, extra_dependencies, true, true);
  }
  // Same as above with explicit control over de-duplication; the two shorter
  // overloads pass true for both flags.
  // NOTE(review): 'dedup_inputs' / 'dedup_outputs' presumably remove repeated
  // entries from the per-node input / output lists - confirm against the
  // definition.
  Status Initialize(
      const GraphDef& graph,
      const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
          extra_dependencies,
      bool dedup_inputs, bool dedup_outputs);

  // The viewed graph, or nullptr if graph_ has not been set.
  const GraphDef* graph() const { return graph_; }
  // Number of entries in the index -> name table.
  inline int num_nodes() const { return index_to_name_.size(); }
  // True iff 'node_name' has an index in this view.
  inline bool has_node(const string& node_name) const {
    return name_to_index_.find(node_name) != name_to_index_.end();
  }
  // Index of 'node_name'. The name is expected to exist (DCHECK); if it does
  // not and the DCHECK does not abort, -1 is returned.
  inline const int index(const string& node_name) const {
    const auto it = name_to_index_.find(node_name);
    DCHECK(it != name_to_index_.end());
    return it == name_to_index_.end() ? -1 : it->second;
  }
  // The NodeDef at position 'node_idx' of the viewed graph. Requires graph_ to
  // be set; 'node_idx' is not range-checked here.
  inline const NodeDef& node(int node_idx) const {
    return graph_->node(node_idx);
  }
  // Name stored for 'node_idx'; the index is not range-checked here.
  inline const string& node_name(int node_idx) const {
    return index_to_name_[node_idx];
  }
  // Entries of the inputs table for 'node_idx' (not range-checked).
  inline const gtl::InlinedVector<int, 4>& inputs(int node_idx) const {
    return inputs_[node_idx];
  }
  // Entries of the outputs table for 'node_idx' (not range-checked).
  inline const gtl::InlinedVector<int, 2>& outputs(int node_idx) const {
    return outputs_[node_idx];
  }

  // Traverse the graph starting at `node_idx`, collecting indices of nodes
  // visited in nodes_found. If a node has an op in `op_types_to_traverse`, the
  // walk continues to its children. It is assumed that *graph_ was not modified
  // after the call to Initialize().
  // If `op_types_to_traverse` is empty the DFS will traverse any node type.
  void DepthFirstSearch(const std::unordered_set<string>& op_types_to_traverse,
                        int node_idx, std::set<int>* nodes_found) const;

  // Returns a string rendering of the view.
  // NOTE(review): the output format is decided by the out-of-line definition.
  string PrintToString() const;

 private:
  // Not owned. Initialized to nullptr so that graph() is well defined before
  // the pointer has been set.
  const GraphDef* graph_ = nullptr;
  std::vector<string> index_to_name_;
  std::unordered_map<string, int> name_to_index_;
  // Indexed by node index; each entry holds node indices.
  std::vector<gtl::InlinedVector<int, 4>> inputs_;
  std::vector<gtl::InlinedVector<int, 2>> outputs_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_H_