author     Allen Lavoie <allenl@google.com>                  2018-08-20 09:48:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-08-20 09:56:23 -0700
commit     b23df6e50255991b06cc8d6596b1075c3bd3f7e9
tree       7d5ad2a5c106e17d8743c90bcd600d07fc3cf75c
parent     d44142b807bba47464d2a873e2dfcd641236591e
Automated rollback of commit 91fd2cd6c3466340d3a69be76993e357662b2009
PiperOrigin-RevId: 209433774
-rw-r--r--  tensorflow/python/kernel_tests/partitioned_variables_test.py  35
-rw-r--r--  tensorflow/python/training/saver.py                           16
2 files changed, 51 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index ba9359d923..1d0c2dceba 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -34,6 +36,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
@@ -622,6 +625,38 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose([-0.4, -0.4], x.eval())
+  def testMetaGraphSaveLoad(self):
+    save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    save_graph = ops.Graph()
+    with save_graph.as_default(), self.test_session(
+        graph=save_graph) as session:
+      partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
+      with variable_scope.variable_scope("root", partitioner=partitioner):
+        v0 = variable_scope.get_variable(
+            "v0", dtype=dtypes.float32, shape=(10, 10))
+        v0_list = v0._get_variable_list()
+        v0_part = v0._get_partitions()
+        self.assertEqual(len(v0_list), 5)
+        self.assertAllEqual(v0_part, (5, 1))
+        variables.global_variables_initializer().run()
+
+        save_graph.get_collection_ref("partvar").append(v0)
+        saver = saver_lib.Saver()
+        save_graph.finalize()
+        save_path = saver.save(sess=session, save_path=save_prefix)
+        previous_value = session.run(
+            save_graph.get_tensor_by_name(v0.name + ":0"))
+
+    restore_graph = ops.Graph()
+    with restore_graph.as_default(), self.test_session(
+        graph=restore_graph) as session:
+      saver = saver_lib.import_meta_graph(save_path + ".meta")
+      saver.restore(sess=session, save_path=save_path)
+      v0, = save_graph.get_collection_ref("partvar")
+      self.assertIsInstance(v0, variables.PartitionedVariable)
+      self.assertAllEqual(
+          previous_value,
+          session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index e35ea81456..274c856686 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -809,6 +809,22 @@ class BaseSaverBuilder(object):
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
else:
+      graph = ops.get_default_graph()
+      # Do some sanity checking on collections containing
+      # PartitionedVariables. If a saved collection has a PartitionedVariable,
+      # the GraphDef needs to include concat ops to get the value (or there'll
+      # be a lookup error on load).
+      check_collection_list = graph.get_all_collection_keys()
+      for collection_type in check_collection_list:
+        for element in graph.get_collection(collection_type):
+          if isinstance(element, variables.PartitionedVariable):
+            try:
+              graph.get_operation_by_name(element.name)
+            except KeyError:
+              # Create a concat op for this PartitionedVariable. The user may
+              # not need it, but we'll try looking it up on MetaGraph restore
+              # since it's in a collection.
+              element.as_tensor()
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
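For context, a minimal sketch of the situation the new check in saver.py handles, written against the TF 1.x graph-mode API; the graph and variable names here are illustrative, not taken from the commit. A PartitionedVariable stored in a collection has no single op matching its name until as_tensor() builds the concatenated value, so a MetaGraph restore that resolves the collection entry by name would otherwise hit a lookup error.

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  partitioner = tf.fixed_size_partitioner(5, axis=0)
  with tf.variable_scope("root", partitioner=partitioner):
    v = tf.get_variable("v", shape=(10, 10), dtype=tf.float32)
  # Stash the PartitionedVariable in a collection, as the new test does.
  tf.add_to_collection("partvar", v)
  try:
    # Without the concat op, no op named "root/v" exists in the graph, so a
    # MetaGraph restore that looks the collection entry up by name fails.
    graph.get_operation_by_name(v.name)
  except KeyError:
    # Building the concatenated value creates an op carrying the variable's
    # name, which is what the patched SaverDef construction now ensures
    # before the MetaGraphDef is written.
    v.as_tensor()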