diff options
Diffstat (limited to 'tensorflow/core/framework/op_def_builder.h')
-rw-r--r-- | tensorflow/core/framework/op_def_builder.h | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h new file mode 100644 index 0000000000..017338c508 --- /dev/null +++ b/tensorflow/core/framework/op_def_builder.h @@ -0,0 +1,109 @@ +// Class and associated machinery for specifying an Op's OpDef for Op +// registration. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ + +#include <string> +#include <vector> +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Builder class passed to the REGISTER_OP() macro. +class OpDefBuilder { + public: + // Constructs an OpDef with just the name field set. + explicit OpDefBuilder(StringPiece op_name); + + // Adds an attr to this OpDefBuilder (and returns *this). The spec has + // format "<name>:<type>" or "<name>:<type>=<default>" + // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]* + // (by convention only using capital letters for attrs that can be inferred) + // <type> can be: + // "string", "int", "float", "bool", "type", "shape", or "tensor" + // "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}" + // (meaning "type" with a restriction on valid values) + // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" + // (meaning "string" with a restriction on valid values) + // "list(string)", ..., "list(tensor)", "list(numbertype)", ... + // (meaning lists of the above types) + // "int >= 2" (meaning "int" with a restriction on valid values) + // "list(string) >= 2", "list(int) >= 2" + // (meaning "list(string)" / "list(int)" with length at least 2) + // <default>, if included, should use the Proto text format + // of <type>. For lists use [a, b, c] format. + // + // Note that any attr specifying the length of an input or output will + // get a default minimum of 1 unless the >= # syntax is used. + // + // TODO(josh11b): Perhaps support restrictions and defaults as optional + // extra arguments to Attr() instead of encoding them in the spec string. + // TODO(josh11b): Would like to have better dtype handling for tensor attrs: + // * Ability to say the type of an input/output matches the type of + // the tensor. + // * Ability to restrict the type of the tensor like the existing + // restrictions for type attrs. + // Perhaps by linking the type of the tensor to a type attr? + OpDefBuilder& Attr(StringPiece spec); + + // Adds an input or ouput to this OpDefBuilder (and returns *this). + // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)" + // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be: + // * For a single tensor: <type> + // * For a sequence of tensors with the same type: <number>*<type> + // * For a sequence of tensors with different types: <type-list> + // Where: + // <type> is either one of "float", "int32", "string", ... + // or the name of an attr (see above) with type "type". + // <number> is the name of an attr with type "int". + // <type-list> is the name of an attr with type "list(type)". + // TODO(josh11b): Indicate Ref() via an optional argument instead of + // in the spec? + // TODO(josh11b): SparseInput() and SparseOutput() matching the Python + // handling? + OpDefBuilder& Input(StringPiece spec); + OpDefBuilder& Output(StringPiece spec); + + // Turns on the indicated boolean flag in this OpDefBuilder (and + // returns *this). + OpDefBuilder& SetIsCommutative(); + OpDefBuilder& SetIsAggregate(); + OpDefBuilder& SetIsStateful(); + OpDefBuilder& SetAllowsUninitializedInput(); + + // Adds docs to this OpDefBuilder (and returns *this). + // Docs have the format: + // <1-line summary> + // <rest of the description> + // <name>: <description of name> + // <name>: <description of name> + // <if long, indent the description on subsequent lines> + // Where <name> is the name of an attr, input, or output. Please + // wrap docs at 72 columns so that it may be indented in the + // generated output. For tensor inputs or outputs (not attrs), you + // may start the description with an "=" (like name:= <description>) + // to suppress the automatically-generated type documentation in + // generated output. + OpDefBuilder& Doc(StringPiece text); + + // Sets *op_def to the requested OpDef, or returns an error. + // Must be called after all of the above methods. + // Note that OpDefBuilder only reports parsing errors. You should also + // call ValidateOpDef() to detect other problems. + Status Finalize(OpDef* op_def) const; + + private: + OpDef op_def_; + std::vector<string> attrs_; + std::vector<string> inputs_; + std::vector<string> outputs_; + string doc_; + std::vector<string> errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ |