// Source: root/tensorflow/core/graph/graph_constructor.cc
// (blob edf71ec1684ac4528849f50bfb1516a714c8a3bd)
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_constructor.h"

#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"

namespace tensorflow {

namespace {
// Returns true iff the op type recorded in 'node_def' is exactly "Merge".
inline bool IsMerge(const NodeDef& node_def) {
  const auto& op_name = node_def.op();
  return op_name == "Merge";
}
}  // namespace

namespace {

class GraphConstructor {
 public:
  GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
                   Graph* g, Status* status)
      : opts_(opts), gdef_(gdef), g_(g), status_(status) {
    BuildNodeIndex();
    InitFromEdges();
    Convert();
  }

 private:
  void SetError(const string& error);
  void SetNodeError(const NodeDef& node_def, const StringPiece& message) {
    SetError(strings::StrCat("Node '", node_def.name(), "': ", message));
  }
  void BuildNodeIndex();
  void InitFromEdges();
  Node* MakeNode(const NodeDef& node_def);
  void Convert();
  // Calls SetError() and returns false if the type of the output of
  // the source of the edge can't be consumed by destination of the edge.
  // REQUIRES: edge must be a data edge, not a control edge.
  bool TypeValidateEdge(const Edge* edge);

  // From constructor
  const GraphConstructorOptions opts_;
  const GraphDef* gdef_;
  Graph* g_;
  Status* status_;

  // Mapping from node name to the index within gdef_
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // std::unordered_map<> requires that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  // TODO(vrv): Profile this data structure to see if we should use an
  // alternative implementation of std::unordered_map.
  std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> name_index_;

  // Index of NodeDefs in gdef_ with all inputs already converted.
  std::vector<int> ready_;

  // Mapping between index within gdef_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping between index within gdef_ and the index within gdef_ of
  // all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from gdef_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(StringPiece node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    StringPiece name;
    Node* node;
    int index;
  };

  // Used in the conversion from gdef_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(StringPiece name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    StringPiece src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
};

void GraphConstructor::SetError(const string& error) {
  status_->Update(errors::InvalidArgument(error));
}

void GraphConstructor::BuildNodeIndex() {
  // Initialized outside the loop for efficiency
  const char* pattern;
  if (opts_.allow_internal_ops) {
    pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
  } else {
    pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
  }
  RE2 node_name_re(pattern);

  // Validate the node names and add them to name_index_.
  for (int n = 0; n < gdef_->node_size(); ++n) {
    const NodeDef& node_def(gdef_->node(n));
    if (!RE2::FullMatch(node_def.name(), node_name_re)) {
      SetNodeError(node_def, "Node name contains invalid characters");
      return;
    }
    if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()),
                                           NodeInfo(n)))
             .second) {
      SetNodeError(node_def, "Node name is not unique");
      return;
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      SetNodeError(node_def, "Does not specify a type");
      return;
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      SetNodeError(node_def, strings::StrCat("Missing device specification."));
      return;
    }
  }
}

void GraphConstructor::InitFromEdges() {
  const int num_nodes = gdef_->node_size();
  ready_.reserve(num_nodes);
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def(gdef_->node(n));
    if (IsMerge(node_def)) {
      // for merge only wait for one non-control input.
      int32 num_control_edges = 0;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (StringPiece(input_name).starts_with("^")) {
          num_control_edges++;
        }
      }
      pending_count_.push_back(num_control_edges + 1);
    } else {
      pending_count_.push_back(node_def.input_size());
    }
    if (node_def.input_size() == 0) {
      ready_.push_back(n);
      continue;
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (input_name.starts_with("^")) {
        // Control dependence
        input_name.remove_prefix(1);
      }
      TensorId id(ParseTensorName(input_name));
      auto iter = name_index_.find(id.first);
      if (iter == name_index_.end()) {
        SetNodeError(node_def,
                     strings::StrCat("Unknown input node ", node_def.input(i)));
        return;
      }
      outputs_[iter->second.gdef_index].push_back(n);
    }
  }
}

// Adds 'node_def' to g_ (passing status_ to Graph::AddNode) and remembers the
// resulting Node in name_index_ under the node's name.  When
// opts_.expect_device_spec is set, the NodeDef's device string is also stored
// as the node's assigned device name.  Returns the new Node, or nullptr if
// AddNode() returned nullptr.
Node* GraphConstructor::MakeNode(const NodeDef& node_def) {
  Node* const created = g_->AddNode(node_def, status_);
  if (created != nullptr) {
    if (opts_.expect_device_spec) {
      created->set_assigned_device_name(node_def.device());
    }
    name_index_[node_def.name()].node = created;
  }
  return created;
}

// Returns how many nodes iterating over g->nodes() yields.
static int CountNodes(Graph* g) {
  int count = 0;
  for (Node* current : g->nodes()) {
    // The loop variable is otherwise unused; streaming it to VLOG keeps the
    // compiler from warning about that.
    VLOG(1) << current;
    ++count;
  }
  return count;
}

void GraphConstructor::Convert() {
  std::vector<InputInfo> inputs;
  std::vector<EdgeInfo> back_edges;
  int processed = 0;
  // Process the NodeDefs in topological order.
  while (!ready_.empty()) {
    int o = ready_.back();
    ready_.pop_back();
    ++processed;
    const NodeDef& node_def(gdef_->node(o));
    inputs.clear();
    bool in_control_dependence = false;
    bool has_data_back_edge = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name(node_def.input(i));
      if (StringPiece(input_name).starts_with("^")) {
        // A control dependence
        in_control_dependence = true;
        input_name.remove_prefix(1);
      } else {
        if (in_control_dependence) {
          SetNodeError(node_def, strings::StrCat(
                                     "Control dependencies must come after ",
                                     "regular dependencies: input ", input_name,
                                     " of source node ", node_def.name()));
          return;
        }
      }
      TensorId id(ParseTensorName(input_name));
      auto iter = name_index_.find(id.first);
      DCHECK(iter != name_index_.end());
      Node* src_node = iter->second.node;
      if (in_control_dependence) {
        inputs.push_back(InputInfo(id.first, src_node, -1));
      } else {
        if (src_node == nullptr) {
          has_data_back_edge = true;
          inputs.push_back(InputInfo(id.first, src_node, id.second));
        } else {
          if (id.second >= src_node->num_outputs()) {
            SetNodeError(
                node_def,
                strings::StrCat("Connecting to invalid output ", id.second,
                                " of source node ", id.first, " which has ",
                                src_node->num_outputs(), " outputs"));
            return;
          }
          inputs.push_back(InputInfo(id.first, src_node, id.second));
        }
      }
    }
    if (has_data_back_edge && !IsMerge(node_def)) {
      SetError(strings::StrCat(
          node_def.name(),
          " had a back edge. But only Merge can have back edges."));
      return;
    }

    Node* node = MakeNode(node_def);
    if (node == nullptr) return;

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges.push_back(
            EdgeInfo(inputs[i].name, inputs[i].index, node, i));
      } else if (inputs[i].index == -1) {
        g_->AddControlEdge(inputs[i].node, node);
      } else {
        const Edge* edge =
            g_->AddEdge(inputs[i].node, inputs[i].index, node, i);
        if (!TypeValidateEdge(edge)) return;
      }
    }

    // Update pending_count_ for outputs.
    for (size_t i = 0; i < outputs_[o].size(); ++i) {
      const int output = outputs_[o][i];
      pending_count_[output]--;
      if (pending_count_[output] == 0) {
        ready_.push_back(output);
      }
    }
  }

  // Add the back edges after all nodes are created.
  for (auto e : back_edges) {
    Node* src_node = name_index_[e.src_name].node;
    if (e.src_index == -1) {
      g_->AddControlEdge(src_node, e.dst_node);
    } else {
      const Edge* edge =
          g_->AddEdge(src_node, e.src_index, e.dst_node, e.dst_index);
      if (!TypeValidateEdge(edge)) return;
    }

    VLOG(2) << "Add back edge: " << src_node->name() << " -> "
            << e.dst_node->name();
  }

  if (processed < gdef_->node_size()) {
    SetError(
        strings::StrCat(gdef_->node_size() - processed, " nodes in a cycle"));
    return;
  }

  if (status_->ok()) {
    FixupSourceAndSinkEdges(g_);

    if (opts_.optimizer_do_cse) {
      if (!back_edges.empty()) {
        LOG(WARNING) << "Not doing CSE.  We need to figure out how to handle "
                     << "loops in the CSE phase.";
      } else {
        VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes";
        OptimizeCSE(g_, opts_.cse_consider_function);
        VLOG(1) << "Finished CSE: graph of " << CountNodes(g_) << " nodes";
      }
    }
  }
}

bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
  DataType src_out = edge->src()->output_type(edge->src_output());
  DataType dst_in = edge->dst()->input_type(edge->dst_input());
  if (!TypesCompatible(dst_in, src_out)) {
    SetError(strings::StrCat(
        "Input ", edge->dst_input(), " of node ", edge->dst()->name(),
        " was passed ", DataTypeString(src_out), " from ", edge->src()->name(),
        ":", edge->src_output(), " incompatible with expected ",
        DataTypeString(dst_in), "."));
    return false;
  }
  return true;
}

}  // namespace

// ----------------------------------------------------------------------------
// ConvertGraphDefToGraph
// ----------------------------------------------------------------------------

// Builds the contents of 'gdef' into '*g' according to 'opts'.  The
// GraphConstructor does all of its work while being constructed and writes
// any failure into the Status that is returned here.
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  Status result;
  GraphConstructor builder(opts, &gdef, g, &result);
  return result;
}

// ----------------------------------------------------------------------------
// CopyGraph
// ----------------------------------------------------------------------------
void CopyGraph(const Graph& src, Graph* dest) {
  for (Node* n : dest->nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy the nodes
  std::unordered_map<Node*, Node*>
      node_map;  // "Node in src" -> "Node in *dest"
  node_map[src.source_node()] = dest->source_node();
  node_map[src.sink_node()] = dest->sink_node();
  for (Node* n : src.nodes()) {
    if (n->IsSource() || n->IsSink()) continue;
    CHECK(n->IsOp());
    node_map[n] = dest->CopyNode(n);
  }

  // Copy the edges
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

}  // namespace tensorflow