diff options
Diffstat (limited to 'tensorflow/core/framework/node_def_builder.cc')
-rw-r--r-- | tensorflow/core/framework/node_def_builder.cc | 49 |
1 files changed, 41 insertions, 8 deletions
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index b6f5838528..f3091ad286 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -22,11 +22,24 @@ limitations under the License. namespace tensorflow { -NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name, +NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) + : node(n.ToString()), index(i), data_type(dt) {} + +NodeDefBuilder::NodeOut::NodeOut() { + // uninitialized, call Reset() before use. +} + +void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { + node = n.ToString(); + index = i; + data_type = dt; +} + +NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry) { - node_def_.set_name(name); + node_def_.set_name(name.ToString()); Status status; - op_def_ = op_registry->LookUp(op_name, &status); + op_def_ = op_registry->LookUp(op_name.ToString(), &status); if (op_def_ == nullptr) { errors_.push_back(status.error_message()); inputs_specified_ = 0; @@ -35,9 +48,9 @@ NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name, } } -NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def) +NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) : op_def_(op_def) { - node_def_.set_name(name); + node_def_.set_name(name.ToString()); Initialize(); } @@ -72,7 +85,7 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { } void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, - const string& src_node, int src_index, + StringPiece src_node, int src_index, DataType dt) { AddInput(src_node, src_index); @@ -129,7 +142,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, } } -void NodeDefBuilder::AddInput(const string& src_node, int src_index) { +void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { if (src_node.empty()) { errors_.push_back("Empty input node name"); } else if (src_node[0] == '^') { @@ -138,7 +151,7 @@ void NodeDefBuilder::AddInput(const string& src_node, int src_index) { } else if (src_index > 0) { node_def_.add_input(strings::StrCat(src_node, ":", src_index)); } else { - node_def_.add_input(src_node); + node_def_.add_input(src_node.ToString()); } } @@ -160,6 +173,16 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, } } +NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { + control_inputs_.push_back(src_node.ToString()); + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { + node_def_.set_device(device_spec.ToString()); + return *this; +} + Status NodeDefBuilder::Finalize(NodeDef* node_def) const { const std::vector<string>* errors_ptr = &errors_; std::vector<string> errors_storage; @@ -206,4 +229,14 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const { } } +void NodeDefBuilder::CheckInconsistency(StringPiece attr_name, + const AttrValue& found, + const AttrValue& attr_value) { + if (!AreAttrValuesEqual(found, attr_value)) { + errors_.push_back(strings::StrCat( + "Inconsistent values for attr '", attr_name, "' ", + SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value))); + } +} + } // namespace tensorflow |