diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-20 09:48:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-20 09:56:23 -0700 |
commit | b23df6e50255991b06cc8d6596b1075c3bd3f7e9 (patch) | |
tree | 7d5ad2a5c106e17d8743c90bcd600d07fc3cf75c | |
parent | d44142b807bba47464d2a873e2dfcd641236591e (diff) |
Automated rollback of commit 91fd2cd6c3466340d3a69be76993e357662b2009
PiperOrigin-RevId: 209433774
-rw-r--r-- | tensorflow/python/kernel_tests/partitioned_variables_test.py | 35 | ||||
-rw-r--r-- | tensorflow/python/training/saver.py | 16 |
2 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index ba9359d923..1d0c2dceba 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -34,6 +36,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib class PartitionerCreatorsTest(test.TestCase): @@ -622,6 +625,38 @@ class PartitionedVariablesTestCase(test.TestCase): variables.global_variables_initializer().run() self.assertAllClose([-0.4, -0.4], x.eval()) + def testMetaGraphSaveLoad(self): + save_prefix = os.path.join(self.get_temp_dir(), "ckpt") + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0) + with variable_scope.variable_scope("root", partitioner=partitioner): + v0 = variable_scope.get_variable( + "v0", dtype=dtypes.float32, shape=(10, 10)) + v0_list = v0._get_variable_list() + v0_part = v0._get_partitions() + self.assertEqual(len(v0_list), 5) + self.assertAllEqual(v0_part, (5, 1)) + variables.global_variables_initializer().run() + + save_graph.get_collection_ref("partvar").append(v0) + saver = saver_lib.Saver() + save_graph.finalize() + save_path = saver.save(sess=session, save_path=save_prefix) + previous_value = session.run( + save_graph.get_tensor_by_name(v0.name + ":0")) + + restore_graph = ops.Graph() + with restore_graph.as_default(), self.test_session( + graph=restore_graph) as session: + saver = saver_lib.import_meta_graph(save_path + ".meta") + saver.restore(sess=session, save_path=save_path) + v0, = save_graph.get_collection_ref("partvar") + self.assertIsInstance(v0, variables.PartitionedVariable) + self.assertAllEqual( + previous_value, + session.run(restore_graph.get_tensor_by_name(v0.name + ":0"))) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index e35ea81456..274c856686 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -809,6 +809,22 @@ class BaseSaverBuilder(object): keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, version=self._write_version) else: + graph = ops.get_default_graph() + # Do some sanity checking on collections containing + # PartitionedVariables. If a saved collection has a PartitionedVariable, + # the GraphDef needs to include concat ops to get the value (or there'll + # be a lookup error on load). + check_collection_list = graph.get_all_collection_keys() + for collection_type in check_collection_list: + for element in graph.get_collection(collection_type): + if isinstance(element, variables.PartitionedVariable): + try: + graph.get_operation_by_name(element.name) + except KeyError: + # Create a concat op for this PartitionedVariable. The user may + # not need it, but we'll try looking it up on MetaGraph restore + # since it's in a collection. + element.as_tensor() return saver_pb2.SaverDef( filename_tensor_name=filename_tensor.name, save_tensor_name=save_tensor.name, |