aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/node_def_builder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/node_def_builder.cc')
-rw-r--r--tensorflow/core/framework/node_def_builder.cc49
1 files changed, 41 insertions, 8 deletions
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index b6f5838528..f3091ad286 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -22,11 +22,24 @@ limitations under the License.
namespace tensorflow {
-NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
+ : node(n.ToString()), index(i), data_type(dt) {}
+
+NodeDefBuilder::NodeOut::NodeOut() {
+ // uninitialized, call Reset() before use.
+}
+
+void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
+ node = n.ToString();
+ index = i;
+ data_type = dt;
+}
+
+NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Status status;
- op_def_ = op_registry->LookUp(op_name, &status);
+ op_def_ = op_registry->LookUp(op_name.ToString(), &status);
if (op_def_ == nullptr) {
errors_.push_back(status.error_message());
inputs_specified_ = 0;
@@ -35,9 +48,9 @@ NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
}
}
-NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Initialize();
}
@@ -72,7 +85,7 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
}
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
- const string& src_node, int src_index,
+ StringPiece src_node, int src_index,
DataType dt) {
AddInput(src_node, src_index);
@@ -129,7 +142,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
}
}
-void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
+void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
if (src_node.empty()) {
errors_.push_back("Empty input node name");
} else if (src_node[0] == '^') {
@@ -138,7 +151,7 @@ void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(src_node);
+ node_def_.add_input(src_node.ToString());
}
}
@@ -160,6 +173,16 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
}
+NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
+ control_inputs_.push_back(src_node.ToString());
+ return *this;
+}
+
+NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
+ node_def_.set_device(device_spec.ToString());
+ return *this;
+}
+
Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
const std::vector<string>* errors_ptr = &errors_;
std::vector<string> errors_storage;
@@ -206,4 +229,14 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
}
}
+void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
+ const AttrValue& found,
+ const AttrValue& attr_value) {
+ if (!AreAttrValuesEqual(found, attr_value)) {
+ errors_.push_back(strings::StrCat(
+ "Inconsistent values for attr '", attr_name, "' ",
+ SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
+ }
+}
+
} // namespace tensorflow