aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_def_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_def_builder.h')
-rw-r--r--tensorflow/core/graph/graph_def_builder.h181
1 files changed, 181 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
new file mode 100644
index 0000000000..bb72f9eea6
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -0,0 +1,181 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+
+#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// Given a function like:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// if (opts.HaveError()) return nullptr;
+// static const string kOpName = "Identity";
+// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
+// opts.op_registry());
+// node_builder.Input(input);
+// return opts.FinalizeBuilder(&node_builder);
+// }
+// } // namspace ops
+//
+// // Or, alternatively:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// static const string kOpName = "Identity";
+// return UnaryOp(kOpName, input, opts);
+// }
+// } // namspace ops
+//
+// You call it like:
+// GraphDefBuilder b;
+// using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+// Node* a = Const(7, b.opts());
+// // Note: WithName() returns a copy, opts is unchanged.
+// Node* b = Const(5, b.opts().WithName("control-input"));
+// Node* c = Identity(a, b.opts().WithControlInput(b));
+// GraphDef graph_def;
+// Status status = b.ToGraphDef(&graph_def);
+// if (!status.ok()) { /* Handle error */ }
+//
+// In tests you can skip the status handling via:
+// GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+// ...
+// b.ToGraphDef(&graph_def);
+
+class GraphDefBuilder {
+ public:
+ // Options for adding a Node to a Graph.
+ class Options {
+ public:
+ // Sets the Graph (that Nodes will be added to) and the status. The
+ // status may be set to nullptr, in which case errors cause CHECK
+ // failures. The graph and status must outlive *this.
+ Options(Graph* graph, Status* status);
+ ~Options();
+
+ // Methods for setting options. These are const methods: they
+ // return a copy of *this with the option set.
+ Options WithName(StringPiece name) const;
+ Options WithDevice(StringPiece device) const;
+ Options WithControlInput(Node* control_input) const;
+ Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
+
+ // Override the default value for an optional attr.
+ template <class T>
+ Options WithAttr(StringPiece attr_name, T&& value) const {
+ return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
+ }
+ // Note: overload needed to allow {...} expressions for value.
+ template <class T>
+ Options WithAttr(StringPiece attr_name,
+ std::initializer_list<T> value) const {
+ return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
+ }
+
+ // Methods for using options from a function that creates a Node.
+
+ // Returns true if the status associated with *this has an error.
+ // Use this to skip processing that may depend on prior results.
+ bool HaveError() const { return status_ != nullptr && !status_->ok(); }
+
+ // Given the Op type name, return a name for a node of that type.
+ // Uses the value set in WithName() if that has been called. Otherwise,
+ // returns a name built out of the Op type name.
+ string GetNameForOp(StringPiece op) const;
+
+ // Sets the device, adds control inputs, adds attrs, and calls Finalize().
+ // If Finalize returns an error, it is saved and this function returns
+ // nullptr.
+ Node* FinalizeBuilder(NodeBuilder* builder) const;
+
+ // Updates the associated status, if any, or calls TF_CHECK_OK if none.
+ void UpdateStatus(const Status& status) const;
+
+ // Accessor
+ const OpRegistryInterface* op_registry() const {
+ return graph_->op_registry();
+ }
+
+ private:
+ Options WithNameImpl(StringPiece name);
+ Options WithDeviceImpl(StringPiece device);
+ Options WithControlInputImpl(Node* control_input);
+ Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
+ template <class T>
+ Options WithAttrImpl(StringPiece name, T&& value) {
+ attrs_.emplace_back(name.ToString(), AttrValue());
+ SetAttrValue(std::forward<T>(value), &attrs_.back().second);
+ return *this;
+ }
+
+ Graph* const graph_;
+ Status* const status_;
+ string name_;
+ string device_;
+ std::vector<Node*> control_inputs_;
+ std::vector<std::pair<string, AttrValue>> attrs_;
+ };
+
+ // Start building a new graph.
+ explicit GraphDefBuilder(
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, &status_) {}
+
+ // For use in tests, where you want to fail immediately on error instead
+ // of checking the status at the end.
+ enum TestFailImmediatelyType { kFailImmediately };
+ explicit GraphDefBuilder(
+ TestFailImmediatelyType,
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, nullptr) {}
+
+ // Gets the Options with the associated Graph and Status.
+ const Options& opts() const { return opts_; }
+
+ // Once all the nodes have been added, call this to get whether it was
+ // successful, and if so fill *graph_def.
+ Status ToGraphDef(GraphDef* graph_def) const;
+
+ // Like ToGraphDef(), but converts to a Graph (using the default
+ // GraphConstructorOptions).
+ // TODO(josh11b): Make this faster; right now it converts
+ // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
+ // edges from the source and to the sink node, resolves back edges
+ // by name), and makes sure the resulting graph is valid.
+ Status ToGraph(Graph* graph) const;
+
+ private:
+ Graph graph_;
+ Status status_;
+ Options opts_;
+};
+
+namespace ops {
+
+// A NodeOut may either be a regular input or back input. Regular
+// inputs are specified via either a Node* or a Node* and an output
+// index. Back inputs are specified by a node name, output index, and
+// output type.
+typedef NodeBuilder::NodeOut NodeOut;
+
+// For adding an Op with no inputs to a GraphDefBuilder.
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
+
+// For adding an Op with one input to a GraphDefBuilder.
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts);
+
+// For adding an Op with two inputs to a GraphDefBuilder.
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts);
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_