diff options
author | 2017-10-23 06:03:28 -0700 | |
---|---|---|
committer | 2017-10-23 06:10:39 -0700 | |
commit | dc13a8e2f7cfd56121347f5596f8b5a770da41c9 (patch) | |
tree | e9e4deac60e2951d5dda16ae0fa1eac1e0b8bd02 | |
parent | eea089bdb66597c9e66180d39b94eea2c17be93e (diff) |
Fix import of meta graphs with partitioned variables into a scope.
Saver inspects SliceInfo to decide the variable name when creating a
checkpoint. Before this fix even if a partitioned variable ("weights")
was imported into a scope "a" it would still be checkpointed as ("weights")
instead of ("a/weights") since import_scoped_meta_graph was not adjusting
the SliceInfo.
WARNING: if you use import_meta_graph on graphs with partitioned_variables WITH an import_scope argument AND then create a Saver to write/read checkpoints this change
may break your checkpoint loading.
PiperOrigin-RevId: 173105796
-rw-r--r-- | tensorflow/python/framework/meta_graph_test.py | 39 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 3 |
2 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 65abb69599..06cee46bf6 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -36,8 +36,10 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -657,5 +659,42 @@ class MetaGraphWithVariableScopeTest(test.TestCase): initializer = variables.local_variables_initializer() +class ExportImportAcrossScopesTest(test.TestCase): + + def testPartionedVariables(self): + def make_graph_with_partitioned_variables(): + variable_scope.get_variable( + name="weights", + partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0), + initializer=random_ops.truncated_normal([100, 10])) + self._testExportImportAcrossScopes(make_graph_with_partitioned_variables) + + def _testExportImportAcrossScopes(self, graph_fn): + """Tests export and importing a graph across scopes. + + Args: + graph_fn: A closure that creates a graph on the current scope. + """ + with ops.Graph().as_default() as original_graph: + with variable_scope.variable_scope("dropA/dropB/keepA"): + graph_fn() + exported_meta_graph_def = meta_graph.export_scoped_meta_graph( + graph=original_graph, + export_scope="dropA/dropB")[0] + + with ops.Graph().as_default() as imported_graph: + meta_graph.import_scoped_meta_graph( + exported_meta_graph_def, + import_scope="importA") + + with ops.Graph().as_default() as expected_graph: + with variable_scope.variable_scope("importA/keepA"): + graph_fn() + + result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] + expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] + self.assertProtoEquals(expected, result) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 90b4f25d81..0272f77176 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -394,7 +394,8 @@ class Variable(object): import_scope=import_scope)) if variable_def.HasField("save_slice_info_def"): self._save_slice_info = Variable.SaveSliceInfo( - save_slice_info_def=variable_def.save_slice_info_def) + save_slice_info_def=variable_def.save_slice_info_def, + import_scope=import_scope) else: self._save_slice_info = None self._caching_device = None |