aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-23 06:03:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-23 06:10:39 -0700
commitdc13a8e2f7cfd56121347f5596f8b5a770da41c9 (patch)
treee9e4deac60e2951d5dda16ae0fa1eac1e0b8bd02
parenteea089bdb66597c9e66180d39b94eea2c17be93e (diff)
Fix import of meta graphs with partitioned variables into a scope.
Saver inspects SliceInfo to decide the variable name when creating a checkpoint. Before this fix even if a partitioned variable ("weights") was imported into a scope "a" it would still be checkpointed as ("weights") instead of ("a/weights") since import_scoped_meta_graph was not adjusting the SliceInfo. WARNING: if you use import_meta_graph on graphs with partitioned_variables WITH an import_scope argument AND then create a Saver to write/read checkpoints this change may break your checkpoint loading. PiperOrigin-RevId: 173105796
-rw-r--r--tensorflow/python/framework/meta_graph_test.py39
-rw-r--r--tensorflow/python/ops/variables.py3
2 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 65abb69599..06cee46bf6 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -36,8 +36,10 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -657,5 +659,42 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
initializer = variables.local_variables_initializer()
+class ExportImportAcrossScopesTest(test.TestCase):
+
+ def testPartionedVariables(self):
+ def make_graph_with_partitioned_variables():
+ variable_scope.get_variable(
+ name="weights",
+ partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0),
+ initializer=random_ops.truncated_normal([100, 10]))
+ self._testExportImportAcrossScopes(make_graph_with_partitioned_variables)
+
+ def _testExportImportAcrossScopes(self, graph_fn):
+ """Tests export and importing a graph across scopes.
+
+ Args:
+ graph_fn: A closure that creates a graph on the current scope.
+ """
+ with ops.Graph().as_default() as original_graph:
+ with variable_scope.variable_scope("dropA/dropB/keepA"):
+ graph_fn()
+ exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
+ graph=original_graph,
+ export_scope="dropA/dropB")[0]
+
+ with ops.Graph().as_default() as imported_graph:
+ meta_graph.import_scoped_meta_graph(
+ exported_meta_graph_def,
+ import_scope="importA")
+
+ with ops.Graph().as_default() as expected_graph:
+ with variable_scope.variable_scope("importA/keepA"):
+ graph_fn()
+
+ result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
+ expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
+ self.assertProtoEquals(expected, result)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 90b4f25d81..0272f77176 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -394,7 +394,8 @@ class Variable(object):
import_scope=import_scope))
if variable_def.HasField("save_slice_info_def"):
self._save_slice_info = Variable.SaveSliceInfo(
- save_slice_info_def=variable_def.save_slice_info_def)
+ save_slice_info_def=variable_def.save_slice_info_def,
+ import_scope=import_scope)
else:
self._save_slice_info = None
self._caching_device = None