aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_def_builder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_def_builder.cc')
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc121
1 files changed, 121 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
new file mode 100644
index 0000000000..979604f948
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -0,0 +1,121 @@
+#include "tensorflow/core/graph/graph_def_builder.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+GraphDefBuilder::Options::Options(Graph* graph, Status* status)
+ : graph_(graph), status_(status) {}
+GraphDefBuilder::Options::~Options() {}
+
+GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
+ StringPiece name) const {
+ return Options(*this).WithNameImpl(name);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
+ StringPiece device) const {
+ return Options(*this).WithDeviceImpl(device);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
+ Node* control_input) const {
+ return Options(*this).WithControlInputImpl(control_input);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
+ gtl::ArraySlice<Node*> control_inputs) const {
+ return Options(*this).WithControlInputsImpl(control_inputs);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
+ StringPiece name) {
+ name_ = name.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
+ StringPiece device) {
+ device_ = device.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
+ Node* control_input) {
+ control_inputs_.push_back(control_input);
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
+ gtl::ArraySlice<Node*> control_inputs) {
+ control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
+ control_inputs.end());
+ return *this;
+}
+
+Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
+ if (status_.ok()) {
+ graph_.ToGraphDef(graph_def);
+ }
+ return status_;
+}
+
+Status GraphDefBuilder::ToGraph(Graph* graph) const {
+ if (status_.ok()) {
+ GraphDef graph_def;
+ graph_.ToGraphDef(&graph_def);
+ GraphConstructorOptions opts;
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
+ }
+ return status_;
+}
+
+string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
+ if (name_.empty()) return graph_->NewName(op);
+ return name_;
+}
+
+Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
+ builder->ControlInputs(control_inputs_);
+ if (!device_.empty()) builder->Device(device_);
+ for (const auto& attr : attrs_) {
+ builder->Attr(attr.first, attr.second);
+ }
+
+ Node* returned_node;
+ UpdateStatus(builder->Finalize(graph_, &returned_node));
+ return returned_node;
+}
+
+void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
+ if (status_ == nullptr) {
+ TF_CHECK_OK(status);
+ } else {
+ status_->Update(status);
+ }
+}
+
+namespace ops {
+
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(input);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(a).Input(b);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+} // end namespace ops
+} // end namespace tensorflow