aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/node_builder.h
blob: d576985a232da5f402b2e2d26bac1d0e1306f82f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
#define TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_

#include <vector>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

// NodeBuilder is a convenience for constructing a Node and inserting it
// into a Graph in one step.  It wraps a NodeDefBuilder, so attrs that can
// be inferred from the inputs are filled in automatically, and attrs left
// unspecified fall back to their default values (where the Op defines
// one).  Typical use:
//
//   Node* node;
//   Status status = NodeBuilder(node_name, op_name)
//                       .Input(...)
//                       .Attr(...)
//                       .Finalize(&graph, &node);
//   if (!status.ok()) return status;
//   // `node` may be used from here on.
class NodeBuilder {
 public:
  // Names one output of a Node, for handing to the Input() overloads
  // below.  There are two flavors: one that points at a Node* already
  // present in the graph, and one that identifies a node by name (for
  // nodes outside the graph being built, or not yet added to it -- for
  // example the source of a back edge, where no Node* exists yet).  The
  // two flavors may be mixed freely, e.g. inside one ArraySlice.
  struct NodeOut {
    // Refers to an output (output 0 unless `i` says otherwise) of an
    // existing Node.
    NodeOut(Node* n, int32 i = 0);

    // Refers by name to a node that is not part of the graph being built.
    // Handy when preparing a graph for ExtendSession, or when wiring a
    // back edge to a node that will be added to the graph later.
    NodeOut(StringPiece name, int32 i, DataType t);

    // Exists so that NodeOut can be held in a std::vector<NodeOut>.
    NodeOut();

    Node* node;
    // `error` is true whenever any of these holds:
    // * the NodeOut was default-constructed and never reassigned,
    // * the Node* handed to the constructor was nullptr, or
    // * the index handed to the constructor was out of range.
    bool error;
    string name;
    int32 index;
    DataType dt;
  };

  // Supplies the node's name and its Op -- either as an OpDef directly, or
  // as an op name looked up in a registry.  Everything else is configured
  // through the methods that follow.
  // REQUIRES: The OpDef must satisfy ValidateOpDef().
  NodeBuilder(StringPiece name, StringPiece op_name,
              const OpRegistryInterface* op_registry = OpRegistry::Global());
  NodeBuilder(StringPiece name, const OpDef* op_def);

  // Starts from an already-configured NodeDefBuilder.
  NodeBuilder(const NodeDefBuilder& def_builder);

  // Exactly one Input() call is required per input_arg of the Op,
  // *and the calls must follow the order of the input_args in the OpDef.*

  // Single-tensor inputs.
  NodeBuilder& Input(Node* src_node, int src_index = 0);
  NodeBuilder& Input(NodeOut src);

  // Tensor-list inputs.
  NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);

  // Control dependencies: this node may only run after src_node(s).
  NodeBuilder& ControlInput(Node* src_node);
  NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);

  // Records the "requested device spec" on the NodeDef.  This is separate
  // from the "assigned device" that lives on the Node.
  NodeBuilder& Device(StringPiece device_spec);

  // Records the device name destined for the "assigned device" field of
  // tensorflow::Node.
  NodeBuilder& AssignedDevice(StringPiece device);

  // Assigns a value to an attr.  attr_name must be one of the attrs the Op
  // declares, and value must have the matching type; the legal value types
  // are the ones accepted by SetAttrValue() in
  // ../framework/attr_value_util.h.  Attrs that can be determined from the
  // inputs are set automatically.
  template <class T>
  NodeBuilder& Attr(StringPiece attr_name, T&& value);
  template <class T>
  NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);

  // Validates the node described so far and adds it to *graph, together
  // with edges for every input that is not a back edge.  When created_node
  // is non-null, *created_node receives the new node, or nullptr if an
  // error occurred.
  Status Finalize(Graph* graph, Node** created_node) const;

  // Read-only access to values supplied at construction time.
  const string& node_name() const { return def_builder_.node_name(); }
  const OpDef& op_def() const { return def_builder_.op_def(); }

 private:
  // When `node` is non-null and `i` is a valid output index, clears *error
  // and returns the type of output `i`.  Otherwise sets *error and returns
  // DT_FLOAT as a fallback type.
  static DataType SafeGetOutput(const Node* node, int i, bool* error) {
    const bool in_range =
        node != nullptr && i >= 0 && i < node->num_outputs();
    *error = !in_range;
    return in_range ? node->output_type(i) : DT_FLOAT;
  }

  // Appends a message to errors_ if SafeGetOutput() reports a range error.
  void AddIndexError(const Node* node, int i);

  // Sets *dt and returns true if i is in range; this is SafeGetOutput()
  // and AddIndexError() combined.
  bool GetOutputType(const Node* node, int i, DataType* dt);

  NodeDefBuilder def_builder_;
  std::vector<NodeOut> inputs_;
  std::vector<Node*> control_inputs_;
  std::vector<string> errors_;
  string assigned_device_;
};

// IMPLEMENTATION -------------------------------------------------------------

// Hands `value` to NodeDefBuilder::Attr() with its original value category
// preserved (via std::forward), then returns *this so calls can be chained.
template <typename T>
NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) {
  this->def_builder_.Attr(attr_name, std::forward<T>(value));
  return *this;
}

// Overload for braced lists such as Attr("shape", {1, 2}).  A braced list
// cannot be deduced through the T&& overload, so it lands here instead.
// The list is passed to NodeDefBuilder::Attr() unchanged, and *this is
// returned for chaining.
template <typename T>
NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
                               std::initializer_list<T> value) {
  this->def_builder_.Attr(attr_name, value);
  return *this;
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_