diff options
author | 2017-11-03 18:31:22 -0700 | |
---|---|---|
committer | 2017-11-03 18:35:31 -0700 | |
commit | 11c8cc337eb5e41c1695e4d6f4e8b25cdd4d9545 (patch) | |
tree | c793a925186aff7cb96131bb0a5c44bed89f5996 /tensorflow/core/graph/graph_constructor.cc | |
parent | 7b2b720704ee802fd468334f640304e9036cf76c (diff) |
Add uniquify_names option to ImportGraphDef.
This option allows ImportGraphDef to mimic the behavior of the Python
import_graph_def function, which automatically creates unique node
names instead of raising an exception (this is due to the Python op
construction logic, not import_graph_def directly). This change is
a steps towards switching import_graph_def to use the C API version.
PiperOrigin-RevId: 174541334
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.cc')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 121 |
1 files changed, 100 insertions, 21 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 8fe4f535fb..753cb260e5 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -75,6 +75,7 @@ class GraphConstructor { prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/") ? in.prefix : in.prefix + "/"), + uniquify_names(in.uniquify_names), input_map(in.input_map), skip_mapped_nodes(in.skip_mapped_nodes), control_dependencies(in.control_dependencies), @@ -86,6 +87,7 @@ class GraphConstructor { bool expect_device_spec; string prefix; + bool uniquify_names; std::map<TensorId, TensorId> input_map; bool skip_mapped_nodes; std::vector<string> control_dependencies; @@ -190,6 +192,20 @@ class GraphConstructor { void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists, NodeDef* node_def); + // Modifies `node_def` if its name isn't unique, or if any of its inputs' + // names have been uniquified. This must be called in topological order on all + // nodes. + void UniquifyNames(const std::vector<bool>& input_already_exists, + NodeDef* node_def); + + // Returns true if `name` already exists in `g_` (either as a node name or + // prefix). + bool NameExists(StringPiece name); + + // Returns a unique version of `original_name`, or `original_name` if it's + // already unique in the graph. + string FindUniqueName(StringPiece original_name); + // From constructor const Options opts_; const NodeDefSlice node_defs_; @@ -224,9 +240,16 @@ class GraphConstructor { // alternative implementation of std::unordered_map. std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> gdef_nodes_; - // Mapping from node name to the existing node in g_ + // Mapping from node name to the existing node in g_. std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_; + // Prefixes already used in the graph. + std::unordered_set<StringPiece, StringPiece::Hasher> existing_prefixes_; + + // Imported node names that have been uniquified. The key is the original + // name, the value is the new unique name. + std::unordered_map<string, string> uniquified_names_; + // Index of NodeDefs in node_defs_ with all inputs already converted. std::vector<int> ready_; @@ -281,6 +304,7 @@ bool NodeNameInValues(const std::vector<string>& control_dependencies, Status GraphConstructor::EnsureNoNameCollisions() { existing_nodes_.reserve(g_->num_nodes()); + // Populate existing_nodes_ and existing_prefixes_. for (Node* n : g_->nodes()) { bool already_exists = !existing_nodes_.insert({n->name(), n}).second; if (already_exists) { @@ -296,18 +320,22 @@ Status GraphConstructor::EnsureNoNameCollisions() { n->name(), "'"); } } + // Add all of node's prefixes to existing_prefixes_ (if it has any). + size_t idx = -1; + while ((idx = n->name().find('/', idx + 1)) != string::npos) { + StringPiece name(n->name()); + existing_prefixes_.insert(name.substr(0, idx)); + } } - if (opts_.prefix.empty() && opts_.importing) { + if (opts_.prefix.empty() && opts_.importing && !opts_.uniquify_names) { for (const NodeDef* n : node_defs_) { const string& name = n->name(); - if (existing_nodes_.find(name) != existing_nodes_.end()) { - return errors::InvalidArgument("Node '", name, + if (NameExists(name)) { + return errors::InvalidArgument("Node name '", name, "' already exists in the Graph"); } } } else if (!opts_.prefix.empty()) { - // Importing nodes with a prefix. No nodes should exist with the same - // prefix. StringPiece prefix_no_slash(opts_.prefix); prefix_no_slash.remove_suffix(1); if (!IsValidNodeName(prefix_no_slash, false)) { @@ -315,13 +343,11 @@ Status GraphConstructor::EnsureNoNameCollisions() { opts_.prefix, "' would lead to invalid node names"); } - for (const Node* n : g_->nodes()) { - if (StringPiece(n->name()).starts_with(opts_.prefix)) { - return errors::InvalidArgument( - "Import node name prefix conflicts with names of nodes already in " - "the Graph, such as '", - n->name(), "'"); - } + if (NameExists(prefix_no_slash)) { + return errors::InvalidArgument("Import node name prefix '", + prefix_no_slash, + "' conflicts with " + "name already used in the graph"); } } return Status::OK(); @@ -663,19 +689,18 @@ void GraphConstructor::AddControlDependencies( void GraphConstructor::AddPrefixToNodeDef( const std::vector<bool>& input_already_exists, NodeDef* node_def) { - const string& prefix = opts_.prefix; - if (prefix.empty()) return; - node_def->set_name(strings::StrCat(prefix, node_def->name())); + if (opts_.prefix.empty()) return; + node_def->set_name(strings::StrCat(opts_.prefix, node_def->name())); // Update names of input nodes for (int i = 0; i < node_def->input_size(); ++i) { StringPiece input(node_def->input(i)); // Skip remapped inputs (which already exist in g_ and are not being - // imported) + // imported). if (input_already_exists[i]) continue; if (input.Consume("^")) { - node_def->set_input(i, strings::StrCat("^", prefix, input)); + node_def->set_input(i, strings::StrCat("^", opts_.prefix, input)); } else { - node_def->set_input(i, strings::StrCat(prefix, input)); + node_def->set_input(i, strings::StrCat(opts_.prefix, input)); } } // Update names of colocation groups @@ -685,12 +710,62 @@ void GraphConstructor::AddPrefixToNodeDef( for (int i = 0; i < list->s_size(); ++i) { StringPiece v(list->s(i)); if (v.Consume(kColocationGroupPrefix)) { - list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix, v)); + list->set_s(i, + strings::StrCat(kColocationGroupPrefix, opts_.prefix, v)); } } } } +void GraphConstructor::UniquifyNames( + const std::vector<bool>& input_already_exists, NodeDef* node_def) { + if (NameExists(node_def->name())) { + string old_name = node_def->name(); + node_def->set_name(FindUniqueName(node_def->name())); + uniquified_names_[old_name] = node_def->name(); + } + for (int i = 0; i < node_def->input_size(); ++i) { + // Skip remapped inputs (which already exist in g_ and are not being + // imported). + if (input_already_exists[i]) continue; + TensorId id = ParseTensorName(node_def->input(i)); + // We require that UniquifyNames() is called on all NodeDefs in topological + // order. This guarantees that node_def's inputs will already be uniquified + // if necessary. + auto iter = uniquified_names_.find(id.first.ToString()); + if (iter == uniquified_names_.end()) continue; + id.first = iter->second; + node_def->set_input(i, id.ToString()); + } + // Update names of colocation groups + if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) { + auto* list = + node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); + for (int i = 0; i < list->s_size(); ++i) { + StringPiece v(list->s(i)); + if (v.Consume(kColocationGroupPrefix)) { + auto iter = uniquified_names_.find(v.ToString()); + if (iter == uniquified_names_.end()) continue; + list->set_s(i, strings::StrCat(kColocationGroupPrefix, iter->second)); + } + } + } +} + +bool GraphConstructor::NameExists(StringPiece name) { + if (existing_nodes_.find(name) != existing_nodes_.end()) return true; + return existing_prefixes_.find(name) != existing_prefixes_.end(); +} + +string GraphConstructor::FindUniqueName(StringPiece original_name) { + string name = original_name.ToString(); + int count = 1; + while (NameExists(name)) { + name = strings::StrCat(original_name, "_", count++); + } + return name; +} + Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped) { const OpDef* op_def; @@ -825,7 +900,11 @@ Status GraphConstructor::Convert() { Node* node; if (opts_.importing) { - AddPrefixToNodeDef(input_already_exists, &imported_node_def); + if (!opts_.prefix.empty()) { + AddPrefixToNodeDef(input_already_exists, &imported_node_def); + } else if (opts_.uniquify_names) { + UniquifyNames(input_already_exists, &imported_node_def); + } TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); } TF_RETURN_IF_ERROR(MakeNode(*node_def, &node)); |