aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_constructor.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-11-03 18:31:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 18:35:31 -0700
commit11c8cc337eb5e41c1695e4d6f4e8b25cdd4d9545 (patch)
treec793a925186aff7cb96131bb0a5c44bed89f5996 /tensorflow/core/graph/graph_constructor.cc
parent7b2b720704ee802fd468334f640304e9036cf76c (diff)
Add uniquify_names option to ImportGraphDef.
This option allows ImportGraphDef to mimic the behavior of the Python import_graph_def function, which automatically creates unique node names instead of raising an exception (this is due to the Python op construction logic, not import_graph_def directly). This change is a steps towards switching import_graph_def to use the C API version. PiperOrigin-RevId: 174541334
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.cc')
-rw-r--r--tensorflow/core/graph/graph_constructor.cc121
1 files changed, 100 insertions, 21 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 8fe4f535fb..753cb260e5 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -75,6 +75,7 @@ class GraphConstructor {
prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/")
? in.prefix
: in.prefix + "/"),
+ uniquify_names(in.uniquify_names),
input_map(in.input_map),
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
@@ -86,6 +87,7 @@ class GraphConstructor {
bool expect_device_spec;
string prefix;
+ bool uniquify_names;
std::map<TensorId, TensorId> input_map;
bool skip_mapped_nodes;
std::vector<string> control_dependencies;
@@ -190,6 +192,20 @@ class GraphConstructor {
void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
NodeDef* node_def);
+ // Modifies `node_def` if its name isn't unique, or if any of its inputs'
+ // names have been uniquified. This must be called in topological order on all
+ // nodes.
+ void UniquifyNames(const std::vector<bool>& input_already_exists,
+ NodeDef* node_def);
+
+ // Returns true if `name` already exists in `g_` (either as a node name or
+ // prefix).
+ bool NameExists(StringPiece name);
+
+ // Returns a unique version of `original_name`, or `original_name` if it's
+ // already unique in the graph.
+ string FindUniqueName(StringPiece original_name);
+
// From constructor
const Options opts_;
const NodeDefSlice node_defs_;
@@ -224,9 +240,16 @@ class GraphConstructor {
// alternative implementation of std::unordered_map.
std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> gdef_nodes_;
- // Mapping from node name to the existing node in g_
+ // Mapping from node name to the existing node in g_.
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_;
+ // Prefixes already used in the graph.
+ std::unordered_set<StringPiece, StringPiece::Hasher> existing_prefixes_;
+
+ // Imported node names that have been uniquified. The key is the original
+ // name, the value is the new unique name.
+ std::unordered_map<string, string> uniquified_names_;
+
// Index of NodeDefs in node_defs_ with all inputs already converted.
std::vector<int> ready_;
@@ -281,6 +304,7 @@ bool NodeNameInValues(const std::vector<string>& control_dependencies,
Status GraphConstructor::EnsureNoNameCollisions() {
existing_nodes_.reserve(g_->num_nodes());
+ // Populate existing_nodes_ and existing_prefixes_.
for (Node* n : g_->nodes()) {
bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
if (already_exists) {
@@ -296,18 +320,22 @@ Status GraphConstructor::EnsureNoNameCollisions() {
n->name(), "'");
}
}
+ // Add all of node's prefixes to existing_prefixes_ (if it has any).
+ size_t idx = -1;
+ while ((idx = n->name().find('/', idx + 1)) != string::npos) {
+ StringPiece name(n->name());
+ existing_prefixes_.insert(name.substr(0, idx));
+ }
}
- if (opts_.prefix.empty() && opts_.importing) {
+ if (opts_.prefix.empty() && opts_.importing && !opts_.uniquify_names) {
for (const NodeDef* n : node_defs_) {
const string& name = n->name();
- if (existing_nodes_.find(name) != existing_nodes_.end()) {
- return errors::InvalidArgument("Node '", name,
+ if (NameExists(name)) {
+ return errors::InvalidArgument("Node name '", name,
"' already exists in the Graph");
}
}
} else if (!opts_.prefix.empty()) {
- // Importing nodes with a prefix. No nodes should exist with the same
- // prefix.
StringPiece prefix_no_slash(opts_.prefix);
prefix_no_slash.remove_suffix(1);
if (!IsValidNodeName(prefix_no_slash, false)) {
@@ -315,13 +343,11 @@ Status GraphConstructor::EnsureNoNameCollisions() {
opts_.prefix,
"' would lead to invalid node names");
}
- for (const Node* n : g_->nodes()) {
- if (StringPiece(n->name()).starts_with(opts_.prefix)) {
- return errors::InvalidArgument(
- "Import node name prefix conflicts with names of nodes already in "
- "the Graph, such as '",
- n->name(), "'");
- }
+ if (NameExists(prefix_no_slash)) {
+ return errors::InvalidArgument("Import node name prefix '",
+ prefix_no_slash,
+ "' conflicts with "
+ "name already used in the graph");
}
}
return Status::OK();
@@ -663,19 +689,18 @@ void GraphConstructor::AddControlDependencies(
void GraphConstructor::AddPrefixToNodeDef(
const std::vector<bool>& input_already_exists, NodeDef* node_def) {
- const string& prefix = opts_.prefix;
- if (prefix.empty()) return;
- node_def->set_name(strings::StrCat(prefix, node_def->name()));
+ if (opts_.prefix.empty()) return;
+ node_def->set_name(strings::StrCat(opts_.prefix, node_def->name()));
// Update names of input nodes
for (int i = 0; i < node_def->input_size(); ++i) {
StringPiece input(node_def->input(i));
// Skip remapped inputs (which already exist in g_ and are not being
- // imported)
+ // imported).
if (input_already_exists[i]) continue;
if (input.Consume("^")) {
- node_def->set_input(i, strings::StrCat("^", prefix, input));
+ node_def->set_input(i, strings::StrCat("^", opts_.prefix, input));
} else {
- node_def->set_input(i, strings::StrCat(prefix, input));
+ node_def->set_input(i, strings::StrCat(opts_.prefix, input));
}
}
// Update names of colocation groups
@@ -685,12 +710,62 @@ void GraphConstructor::AddPrefixToNodeDef(
for (int i = 0; i < list->s_size(); ++i) {
StringPiece v(list->s(i));
if (v.Consume(kColocationGroupPrefix)) {
- list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix, v));
+ list->set_s(i,
+ strings::StrCat(kColocationGroupPrefix, opts_.prefix, v));
}
}
}
}
+void GraphConstructor::UniquifyNames(
+ const std::vector<bool>& input_already_exists, NodeDef* node_def) {
+ if (NameExists(node_def->name())) {
+ string old_name = node_def->name();
+ node_def->set_name(FindUniqueName(node_def->name()));
+ uniquified_names_[old_name] = node_def->name();
+ }
+ for (int i = 0; i < node_def->input_size(); ++i) {
+ // Skip remapped inputs (which already exist in g_ and are not being
+ // imported).
+ if (input_already_exists[i]) continue;
+ TensorId id = ParseTensorName(node_def->input(i));
+ // We require that UniquifyNames() is called on all NodeDefs in topological
+ // order. This guarantees that node_def's inputs will already be uniquified
+ // if necessary.
+ auto iter = uniquified_names_.find(id.first.ToString());
+ if (iter == uniquified_names_.end()) continue;
+ id.first = iter->second;
+ node_def->set_input(i, id.ToString());
+ }
+ // Update names of colocation groups
+ if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
+ auto* list =
+ node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
+ for (int i = 0; i < list->s_size(); ++i) {
+ StringPiece v(list->s(i));
+ if (v.Consume(kColocationGroupPrefix)) {
+ auto iter = uniquified_names_.find(v.ToString());
+ if (iter == uniquified_names_.end()) continue;
+ list->set_s(i, strings::StrCat(kColocationGroupPrefix, iter->second));
+ }
+ }
+ }
+}
+
+bool GraphConstructor::NameExists(StringPiece name) {
+ if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
+ return existing_prefixes_.find(name) != existing_prefixes_.end();
+}
+
+string GraphConstructor::FindUniqueName(StringPiece original_name) {
+ string name = original_name.ToString();
+ int count = 1;
+ while (NameExists(name)) {
+ name = strings::StrCat(original_name, "_", count++);
+ }
+ return name;
+}
+
Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
bool* is_node_mapped) {
const OpDef* op_def;
@@ -825,7 +900,11 @@ Status GraphConstructor::Convert() {
Node* node;
if (opts_.importing) {
- AddPrefixToNodeDef(input_already_exists, &imported_node_def);
+ if (!opts_.prefix.empty()) {
+ AddPrefixToNodeDef(input_already_exists, &imported_node_def);
+ } else if (opts_.uniquify_names) {
+ UniquifyNames(input_already_exists, &imported_node_def);
+ }
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
}
TF_RETURN_IF_ERROR(MakeNode(*node_def, &node));