aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_def_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_def_builder.h')
-rw-r--r--tensorflow/core/framework/op_def_builder.h109
1 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
new file mode 100644
index 0000000000..017338c508
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -0,0 +1,109 @@
+// Class and associated machinery for specifying an Op's OpDef for Op
+// registration.
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_OP() macro.
+class OpDefBuilder {
+ public:
+ // Constructs an OpDef with just the name field set.
+ explicit OpDefBuilder(StringPiece op_name);
+
+ // Adds an attr to this OpDefBuilder (and returns *this). The spec has
+ // format "<name>:<type>" or "<name>:<type>=<default>"
+ // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
+ // (by convention only using capital letters for attrs that can be inferred)
+ // <type> can be:
+ // "string", "int", "float", "bool", "type", "shape", or "tensor"
+ // "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
+ // (meaning "type" with a restriction on valid values)
+ // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
+ // (meaning "string" with a restriction on valid values)
+ // "list(string)", ..., "list(tensor)", "list(numbertype)", ...
+ // (meaning lists of the above types)
+ // "int >= 2" (meaning "int" with a restriction on valid values)
+ // "list(string) >= 2", "list(int) >= 2"
+ // (meaning "list(string)" / "list(int)" with length at least 2)
+ // <default>, if included, should use the Proto text format
+ // of <type>. For lists use [a, b, c] format.
+ //
+ // Note that any attr specifying the length of an input or output will
+ // get a default minimum of 1 unless the >= # syntax is used.
+ //
+ // TODO(josh11b): Perhaps support restrictions and defaults as optional
+ // extra arguments to Attr() instead of encoding them in the spec string.
+ // TODO(josh11b): Would like to have better dtype handling for tensor attrs:
+ // * Ability to say the type of an input/output matches the type of
+ // the tensor.
+ // * Ability to restrict the type of the tensor like the existing
+ // restrictions for type attrs.
+ // Perhaps by linking the type of the tensor to a type attr?
+ OpDefBuilder& Attr(StringPiece spec);
+
+ // Adds an input or ouput to this OpDefBuilder (and returns *this).
+ // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
+ // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
+ // * For a single tensor: <type>
+ // * For a sequence of tensors with the same type: <number>*<type>
+ // * For a sequence of tensors with different types: <type-list>
+ // Where:
+ // <type> is either one of "float", "int32", "string", ...
+ // or the name of an attr (see above) with type "type".
+ // <number> is the name of an attr with type "int".
+ // <type-list> is the name of an attr with type "list(type)".
+ // TODO(josh11b): Indicate Ref() via an optional argument instead of
+ // in the spec?
+ // TODO(josh11b): SparseInput() and SparseOutput() matching the Python
+ // handling?
+ OpDefBuilder& Input(StringPiece spec);
+ OpDefBuilder& Output(StringPiece spec);
+
+ // Turns on the indicated boolean flag in this OpDefBuilder (and
+ // returns *this).
+ OpDefBuilder& SetIsCommutative();
+ OpDefBuilder& SetIsAggregate();
+ OpDefBuilder& SetIsStateful();
+ OpDefBuilder& SetAllowsUninitializedInput();
+
+ // Adds docs to this OpDefBuilder (and returns *this).
+ // Docs have the format:
+ // <1-line summary>
+ // <rest of the description>
+ // <name>: <description of name>
+ // <name>: <description of name>
+ // <if long, indent the description on subsequent lines>
+ // Where <name> is the name of an attr, input, or output. Please
+ // wrap docs at 72 columns so that it may be indented in the
+ // generated output. For tensor inputs or outputs (not attrs), you
+ // may start the description with an "=" (like name:= <description>)
+ // to suppress the automatically-generated type documentation in
+ // generated output.
+ OpDefBuilder& Doc(StringPiece text);
+
+ // Sets *op_def to the requested OpDef, or returns an error.
+ // Must be called after all of the above methods.
+ // Note that OpDefBuilder only reports parsing errors. You should also
+ // call ValidateOpDef() to detect other problems.
+ Status Finalize(OpDef* op_def) const;
+
+ private:
+ OpDef op_def_;
+ std::vector<string> attrs_;
+ std::vector<string> inputs_;
+ std::vector<string> outputs_;
+ string doc_;
+ std::vector<string> errors_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_