diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-02-09 17:14:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-09 17:18:33 -0800 |
commit | 5b71a126c4f0733eccefee76b599e0315f052bef (patch) | |
tree | 488b7b1111f7fcbfd99c3ac2898bf82c203b5847 | |
parent | 816f59e6ab53c4553b0325b872b7be5ea73da89b (diff) |
import_graph_def: support "absolute" names with the C API enabled.
Passing a name with a trailing '/' to import_graph_def causes that
name to be used as-is (i.e. it is not appended to the existing name
scope and not de-duped with any existing name scopes. This is in order
to re-use an existing name scope). This didn't work with the C API
enabled because it was set to always have the C API uniquify the
prefix.
The fix is to not uniquify the prefix, since calling name_scope in
import_graph_def already has the logic to uniquify the prefix if
necessary. I'm not sure why I thought we needed the C API to do this
to being with.
In addition, this changes the graph_constructor.cc logic to uniquify
names if the prefix cannot be guaranteed unique (see the new test case
in graph_constructor_test.cc for why/when this is necessary).
PiperOrigin-RevId: 185215326
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 22 | ||||
-rw-r--r-- | tensorflow/python/framework/importer.py | 1 | ||||
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 19 |
4 files changed, 45 insertions, 13 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 2a52c7516e..0629ff32d0 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -374,15 +374,8 @@ Status GraphConstructor::EnsureNoNameCollisions() { return errors::InvalidArgument("Imported node name prefix '", prefix_, "' would lead to invalid node names"); } - if (NameExistsInGraph(prefix_no_slash)) { - if (opts_.uniquify_prefix) { - prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); - } else { - return errors::InvalidArgument("Import node name prefix '", - prefix_no_slash, - "' conflicts with " - "name already used in the graph"); - } + if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) { + prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); } } return Status::OK(); @@ -990,7 +983,10 @@ Status GraphConstructor::Convert() { if (opts_.importing) { if (!prefix_.empty()) { AddPrefixToNodeDef(input_already_exists, &imported_node_def); - } else if (opts_.uniquify_names) { + } + // Note: no need to uniquify names if the prefix already guarantees + // uniqueness + if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) { UniquifyNames(input_already_exists, &imported_node_def); } TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index c59e478f15..963c1dc024 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1834,7 +1834,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[1]->name(), "B_2"); EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0"); - // Import with an already-used prefix + // Import with an already-used prefix and uniquify_prefix = true opts.prefix = "A"; opts.uniquify_prefix = true; results = ImportGraphDefResults(); @@ -1846,9 +1846,27 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A"); // Create B_3 node to keep the A/B numbering in sync - opts = ImportGraphDefOptions(); ExpectOK("node { name: 'B_3' op: 'TestInput' }"); + // Import with an already-used prefix and uniquify_prefix = false + opts.uniquify_prefix = false; + results = ImportGraphDefResults(); + ExpectOK(graph_def_str, opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A/A"); + EXPECT_EQ(results.return_nodes[1]->name(), "A/B"); + EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A"); + + // Repeat the same import + results = ImportGraphDefResults(); + ExpectOK(graph_def_str, opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A/A_1"); + EXPECT_EQ(results.return_nodes[1]->name(), "A/B_1"); + EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A_1:0"); + // Import with existing de-duped node names opts = ImportGraphDefOptions(); opts.uniquify_names = true; diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index cc8f2392ba..6ecc1a40ae 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -270,7 +270,6 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map, """Populates the TF_ImportGraphDefOptions `options`.""" c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix) c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True) - c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True) for input_src, input_dst in input_map.items(): input_src = compat.as_str(input_src) diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index acaec37f81..bf5d9fe093 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -154,6 +154,25 @@ class ImportGraphDefTest(test.TestCase): self.assertEqual(b3.name, "A_3/B") self.assertEqual(list(b3.inputs), [a3.outputs[0]]) + # Import with an already-used name but with a '/' to indicate an + # "absolute" name scope (see the Graph.name_scope docstring). + a_a, a_b = importer.import_graph_def( + graph_def, + return_elements=["A", "B"], + name="A/") + self.assertEqual(a_a.name, "A/A") + self.assertEqual(a_b.name, "A/B") + self.assertEqual(list(a_b.inputs), [a_a.outputs[0]]) + + # Repeat the same import. + a_a1, a_b1 = importer.import_graph_def( + graph_def, + return_elements=["A", "B"], + name="A/") + self.assertEqual(a_a1.name, "A/A_1") + self.assertEqual(a_b1.name, "A/B_1") + self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]]) + # Import with existing de-duped node names a1_1, b1_1 = importer.import_graph_def( self._MakeGraphDef(""" |