aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/node_def_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/node_def_util.h')
-rw-r--r--tensorflow/core/framework/node_def_util.h157
1 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
new file mode 100644
index 0000000000..fce6fd2433
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.h
@@ -0,0 +1,157 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a NodeDef that is more concise
+// than a text-format proto.
+string SummarizeNodeDef(const NodeDef& node_def);
+
+typedef protobuf::Map<string, AttrValue> AttrValueMap;
+
+// Adds an attr with name <name> and value <value> to *node_def.
+// The type of the attr is based on the type of value.
+template <class T>
+void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+// Version to workaround C++'s "perfect" forwarding not being able to
+// forward {...} initialization.
+template <class T>
+void AddNodeAttr(const string& name, std::initializer_list<T> value,
+ NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(value, &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+class AttrSlice {
+ public:
+ AttrSlice(const NodeDef& node_def) // NOLINT(runtime/explicit)
+ : ndef_(&node_def),
+ attrs_(&ndef_->attr()) {}
+
+ explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}
+
+ // Returns the attr with attr_name if found. Otherwise, returns
+ // nullptr.
+ const AttrValue* Find(const string& attr_name) const;
+
+ // Returns the attr_value for attr_name if found. Otherwise, returns a
+ // NotFound status.
+ Status Find(const string& attr_name, const AttrValue** attr_value) const;
+
+ private:
+ const NodeDef* ndef_ = nullptr;
+ const AttrValueMap* attrs_;
+};
+
+// Look up the attr with name attr_name and set *value to its value. If no
+// attr with attr_name is found in node_def, or the attr does not have
+// a matching type, a non-ok status will be returned.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ string* value); // type: "string"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int64* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int32* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ float* value); // type: "float"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ bool* value); // type: "bool"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataType* value); // type: "type"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShapeProto* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShape* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ Tensor* value); // type: "tensor"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<string>* value); // type "list(string)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int64>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int32>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<float>* value); // type "list(float)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<bool>* value); // type "list(bool)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<DataType>* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShapeProto>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShape>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<Tensor>* value); // type: "list(tensor)"
+
+// This version avoids copying the TensorProto.
+// REQUIRES: Must not use *value beyond the lifetime of node_def.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value); // type: "tensor"
+
+// This version avoids copying the NameAttrList.
+// REQUIRES: Must not use *value beyond the lifetime of node_def.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value); // type: "func"
+
+// Computes the input and output types for a specific node, for
+// attr-style ops.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs);
+
+// Validates that the NodeDef:
+// * Defines all expected attrs from the OpDef.
+// * All attrs satisfies constraints from the OpDef.
+// * Has a signature matching SignatureForNode().
+// etc.
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
+
+// Computes the mapping from input/output argument name to the
+// corresponding input/output index range. For example,
+// input "foo" coresponds to input indices
+// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
+typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs);
+
+// Adds default values to *node_def for unspecified attrs from op_def.
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
+
+// Validates the syntax of a NodeDef provided externally.
+//
+// The following is an EBNF-style syntax for NodeDef objects. Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Node = NodeName, Inputs
+// Inputs = ( DataInput * ), ( ControlInput * )
+// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
+// ControlInput = "^", NodeName
+// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
+
+// Returns "status" with kernel's NodeDef attached as additional text
+// in the error message.
+Status AttachDef(const Status& status, const NodeDef& node_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_