aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/node_def_util.h
blob: fce6fd243381573d359da019d929c8cc891aa526 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_

#include <string>
#include <unordered_map>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

// Returns a compact, human-readable rendering of `node_def`, intended to be
// shorter than the protobuf text format of the same NodeDef.
string SummarizeNodeDef(const NodeDef& node_def);

// Attr name -> attr value, the map type exposed by NodeDef::attr().
using AttrValueMap = protobuf::Map<string, AttrValue>;

// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
template <class T>
void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
  AttrValue attr_value;
  SetAttrValue(std::forward<T>(value), &attr_value);
  node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
}

// Version to workaround C++'s "perfect" forwarding not being able to
// forward {...} initialization.
template <class T>
void AddNodeAttr(const string& name, std::initializer_list<T> value,
                 NodeDef* node_def) {
  AttrValue attr_value;
  SetAttrValue(value, &attr_value);
  node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
}

class AttrSlice {
 public:
  AttrSlice(const NodeDef& node_def)  // NOLINT(runtime/explicit)
      : ndef_(&node_def),
        attrs_(&ndef_->attr()) {}

  explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}

  // Returns the attr with attr_name if found.  Otherwise, returns
  // nullptr.
  const AttrValue* Find(const string& attr_name) const;

  // Returns the attr_value for attr_name if found. Otherwise, returns a
  // NotFound status.
  Status Find(const string& attr_name, const AttrValue** attr_value) const;

 private:
  const NodeDef* ndef_ = nullptr;
  const AttrValueMap* attrs_;
};

// Looks up the attr named attr_name in `attrs` and sets *value to its value.
// A NodeDef may be passed directly as `attrs` (AttrSlice converts implicitly
// from NodeDef). Returns a non-ok status if `attrs` has no attr named
// attr_name, or if that attr does not have the type the overload expects;
// the expected attr type of each overload is noted in its trailing comment.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   string* value);  // type "string"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   int64* value);  // type "int"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   int32* value);  // type "int"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   float* value);  // type "float"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   bool* value);  // type "bool"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   DataType* value);  // type "type"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   TensorShapeProto* value);  // type "shape"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   TensorShape* value);  // type "shape"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   Tensor* value);  // type "tensor"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<string>* value);  // type "list(string)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<int64>* value);  // type "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<int32>* value);  // type "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<float>* value);  // type "list(float)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<bool>* value);  // type "list(bool)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<DataType>* value);  // type "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   DataTypeVector* value);  // type "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<TensorShapeProto>* value);  // type "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<TensorShape>* value);  // type "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<Tensor>* value);  // type "list(tensor)"

// This version avoids copying the TensorProto: *value refers to the proto
// stored in the attrs that `attrs` views.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
// AttrValueMap) that `attrs` was constructed from.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   const TensorProto** value);  // type "tensor"

// This version avoids copying the NameAttrList: *value refers to the list
// stored in the attrs that `attrs` views.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
// AttrValueMap) that `attrs` was constructed from.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   const NameAttrList** value);  // type "func"

// Computes the input and output types of a specific node of an attr-style
// op, storing them in *inputs and *outputs.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                         DataTypeVector* inputs, DataTypeVector* outputs);

// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
// * Has attrs that all satisfy the constraints from the OpDef.
// * Has a signature matching SignatureForNode().
// etc.
// NOTE(review): SignatureForNode() is not declared in this header; confirm
// whether the signature computed by InOutTypesForNode() is what is meant.
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);

// Computes the mapping from input/output argument name to the
// corresponding input/output index range.  For example,
// input "foo" corresponds to the half-open range of input indices
//   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
                         NameRangeMap* inputs, NameRangeMap* outputs);

// Adds default values to *node_def for attrs from op_def that *node_def does
// not specify.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);

// Checks that a NodeDef supplied from outside conforms to the syntax below,
// reporting the outcome as a Status.
//
// The grammar is given in EBNF-like notation. Only the fields it mentions are
// checked; the tensorflow::NodeDef protocol buffer has many other fields that
// are not (currently) validated.
//
// Node         = NodeName, Inputs
// Inputs       = ( DataInput * ), ( ControlInput * )
// DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
// ControlInput = "^", NodeName
// NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);

// Yields `status` with the (kernel's) NodeDef `node_def` attached to its
// error message as extra text.
Status AttachDef(const Status& status, const NodeDef& node_def);

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_