diff options
Diffstat (limited to 'tensorflow/core/graph/graph_def_builder.h')
-rw-r--r-- | tensorflow/core/graph/graph_def_builder.h | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h new file mode 100644 index 0000000000..bb72f9eea6 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder.h @@ -0,0 +1,181 @@ +#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ +#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ + +#include <vector> +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Given a function like: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// if (opts.HaveError()) return nullptr; +// static const string kOpName = "Identity"; +// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, +// opts.op_registry()); +// node_builder.Input(input); +// return opts.FinalizeBuilder(&node_builder); +// } +// } // namspace ops +// +// // Or, alternatively: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// static const string kOpName = "Identity"; +// return UnaryOp(kOpName, input, opts); +// } +// } // namspace ops +// +// You call it like: +// GraphDefBuilder b; +// using namespace ::tensorflow::ops; // NOLINT(build/namespaces) +// Node* a = Const(7, b.opts()); +// // Note: WithName() returns a copy, opts is unchanged. +// Node* b = Const(5, b.opts().WithName("control-input")); +// Node* c = Identity(a, b.opts().WithControlInput(b)); +// GraphDef graph_def; +// Status status = b.ToGraphDef(&graph_def); +// if (!status.ok()) { /* Handle error */ } +// +// In tests you can skip the status handling via: +// GraphDefBuilder b(GraphDefBuilder::kFailImmediately); +// ... +// b.ToGraphDef(&graph_def); + +class GraphDefBuilder { + public: + // Options for adding a Node to a Graph. + class Options { + public: + // Sets the Graph (that Nodes will be added to) and the status. The + // status may be set to nullptr, in which case errors cause CHECK + // failures. The graph and status must outlive *this. + Options(Graph* graph, Status* status); + ~Options(); + + // Methods for setting options. These are const methods: they + // return a copy of *this with the option set. + Options WithName(StringPiece name) const; + Options WithDevice(StringPiece device) const; + Options WithControlInput(Node* control_input) const; + Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const; + + // Override the default value for an optional attr. + template <class T> + Options WithAttr(StringPiece attr_name, T&& value) const { + return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value)); + } + // Note: overload needed to allow {...} expressions for value. + template <class T> + Options WithAttr(StringPiece attr_name, + std::initializer_list<T> value) const { + return WithAttr<std::initializer_list<T>>(attr_name, std::move(value)); + } + + // Methods for using options from a function that creates a Node. + + // Returns true if the status associated with *this has an error. + // Use this to skip processing that may depend on prior results. + bool HaveError() const { return status_ != nullptr && !status_->ok(); } + + // Given the Op type name, return a name for a node of that type. + // Uses the value set in WithName() if that has been called. Otherwise, + // returns a name built out of the Op type name. + string GetNameForOp(StringPiece op) const; + + // Sets the device, adds control inputs, adds attrs, and calls Finalize(). + // If Finalize returns an error, it is saved and this function returns + // nullptr. + Node* FinalizeBuilder(NodeBuilder* builder) const; + + // Updates the associated status, if any, or calls TF_CHECK_OK if none. + void UpdateStatus(const Status& status) const; + + // Accessor + const OpRegistryInterface* op_registry() const { + return graph_->op_registry(); + } + + private: + Options WithNameImpl(StringPiece name); + Options WithDeviceImpl(StringPiece device); + Options WithControlInputImpl(Node* control_input); + Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); + template <class T> + Options WithAttrImpl(StringPiece name, T&& value) { + attrs_.emplace_back(name.ToString(), AttrValue()); + SetAttrValue(std::forward<T>(value), &attrs_.back().second); + return *this; + } + + Graph* const graph_; + Status* const status_; + string name_; + string device_; + std::vector<Node*> control_inputs_; + std::vector<std::pair<string, AttrValue>> attrs_; + }; + + // Start building a new graph. + explicit GraphDefBuilder( + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), opts_(&graph_, &status_) {} + + // For use in tests, where you want to fail immediately on error instead + // of checking the status at the end. + enum TestFailImmediatelyType { kFailImmediately }; + explicit GraphDefBuilder( + TestFailImmediatelyType, + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), opts_(&graph_, nullptr) {} + + // Gets the Options with the associated Graph and Status. + const Options& opts() const { return opts_; } + + // Once all the nodes have been added, call this to get whether it was + // successful, and if so fill *graph_def. + Status ToGraphDef(GraphDef* graph_def) const; + + // Like ToGraphDef(), but converts to a Graph (using the default + // GraphConstructorOptions). + // TODO(josh11b): Make this faster; right now it converts + // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds + // edges from the source and to the sink node, resolves back edges + // by name), and makes sure the resulting graph is valid. + Status ToGraph(Graph* graph) const; + + private: + Graph graph_; + Status status_; + Options opts_; +}; + +namespace ops { + +// A NodeOut may either be a regular input or back input. Regular +// inputs are specified via either a Node* or a Node* and an output +// index. Back inputs are specified by a node name, output index, and +// output type. +typedef NodeBuilder::NodeOut NodeOut; + +// For adding an Op with no inputs to a GraphDefBuilder. +Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); + +// For adding an Op with one input to a GraphDefBuilder. +Node* UnaryOp(const string& op_name, NodeOut input, + const GraphDefBuilder::Options& opts); + +// For adding an Op with two inputs to a GraphDefBuilder. +Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, + const GraphDefBuilder::Options& opts); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_ |